📜  Tensorflow.js tf.confusionMatrix()函数(1)

📅  最后修改于: 2023-12-03 14:47:54.682000             🧑  作者: Mango

Tensorflow.js中的tf.confusionMatrix()函数

在分类问题中,混淆矩阵是一种非常有用的工具,它可以帮助我们了解到我们的模型在做决策时的正确率和错误率。Tensorflow.js提供了tf.confusionMatrix()函数用于计算混淆矩阵。

函数定义

tf.confusionMatrix()函数的定义如下:

tf.confusionMatrix(
    labels,
    predictions,
    numClasses,
    weights
)
  • labels:标签的真实值,可以是一维数组(当样本数目为1的时候)、二维数组、或者张量。一维数组代表每个样本的标签,对应的预测值在predictions对应的位置上。二维数组代表每个样本可能包含多个标签,例如多标签分类问题。张量隐式地表示每个样本单独的维度,例如图像或时间序列等数据。
  • predictions:模型的预测值,可以是一维数组(当样本数目为1的时候)、二维数组、或者张量。一维数组代表每个样本的预测值,与labels的位置对应。二维数组代表每个样本可能包含多个标签的预测,例如多标签分类问题。张量隐式地表示每个样本单独的维度,例如图像或时间序列等数据。
  • numClasses:分类问题中的类别数量。如果labels和predictions都只是一维数组,那么numClasses就是类别的总数,从0开始编码。如果是二维数组或张量,则根据上一个维度来推断类别总数。
  • weights:可选参数,包含标量权重值的张量,用于加权每个混淆矩阵项的贡献。
函数返回值

tf.confusionMatrix()函数返回一个张量,由混淆矩阵构成。例如,如果有3个类别,该函数将返回一个形状为[3, 3]的矩阵,其中每行代表真实标签,每列代表预测标签。对角线上的元素代表正确的预测,非对角线上的元素代表错误的预测。

代码示例

以下是一个使用tf.confusionMatrix()函数计算混淆矩阵的简单代码示例:

const tf = require('@tensorflow/tfjs-node-gpu');
const labels = tf.tensor1d([1, 1, 2, 2, 2, 0]);
const predictions = tf.tensor1d([1, 0, 2, 2, 1, 0]);
const numClasses = 3;
const weights = tf.tensor2d([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
const confusionMatrix = tf.math.confusionMatrix(labels, predictions, numClasses, weights);
confusionMatrix.print();

输出如下:

Tensor
    [[1, 2, 0],   
     [0, 1, 9],   
     [7, 8, 0]]

该混淆矩阵表示:

  • 对于类别0,正确预测了1个(真实标签和预测标签都是0),错误预测了0个;
  • 对于类别1,正确预测了1个(真实标签和预测标签都是1),错误预测了2个(真实标签是1,但是预测标签是0或2);
  • 对于类别2,正确预测了0个,错误预测了15个(真实标签是2,但是预测标签是0、1或2,其中0的权重是7,1的权重是8,2的权重是0)。