Tensorflow.js tf.confusionMatrix()函数
Tensorflow.js 是一个由谷歌开发的开源库,用于在浏览器或节点环境中运行机器学习模型以及深度学习神经网络。
.confusionMatrix()函数用于根据所述真实标签和预测标签计算混淆矩阵。
句法:
tf.confusionMatrix(labels, predictions, numClasses)
参数:
- 标签:指定的目标标签应该是基于零的整数,有利于类。它的形状为 [numExamples]。其中, numExamples是合并实例的度量。它可以是 tf.Tensor1D、TypedArray 或数组类型。
- 预测:它是声明的预测类别,应该是基于零的整数,有利于类别。它应该具有与所述标签相同的形状。它可以是 tf.Tensor1D、TypedArray 或数组类型。
- numClasses:整数类型的总类数。此外,它的度量应该大于所述标签和预测中的最大元素。它是数字类型。
返回值:返回 tf.Tensor2D 对象。
示例 1:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Defining predictions, labels and
// numClasses
const lab = tf.tensor1d([3, 4, 1, 0, 1], 'int32');
const pred = tf.tensor1d([1, 3, 0, 4, 1], 'int32');
const num_Cls = 2;
// Calling tf.confusionMatrix() method
const output = tf.math.confusionMatrix(lab, pred, num_Cls);
// Printing output
output.print();
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.confusionMatrix() method
const res = tf.math.confusionMatrix(
tf.tensor1d([3.3, 4.5, null, 'a', 'b']),
tf.tensor1d([-2, 5.3, -0.1, 4.3, 12.5]), 4
);
// Printing output
res.print();
输出:
Tensor
[[0, 0],
[1, 1]]
示例 2:
Javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
// Calling tf.confusionMatrix() method
const res = tf.math.confusionMatrix(
tf.tensor1d([3.3, 4.5, null, 'a', 'b']),
tf.tensor1d([-2, 5.3, -0.1, 4.3, 12.5]), 4
);
// Printing output
res.print();
输出:
Tensor
[[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]
参考: https://js.tensorflow.org/api/latest/#confusionMatrix