📅  最后修改于: 2023-12-03 14:47:55.591000             🧑  作者: Mango
Tensorflow.js tf.metrics.categoricalAccuracy()
函数是用于计算多分类问题中模型的准确率(Accuracy)的函数。
在计算准确率时,将模型的预测与真实的标签进行比较,统计有多少比例的样本预测正确。
函数的输入有两个张量:预测值和真实标签。其中,预测值通常是模型预测出来的结果,真实标签是样本真实的分类标签。
函数的输出是一个标量张量,表示模型的准确率。
下面给出一个使用tf.metrics.categoricalAccuracy()
函数计算准确率的示例代码:
// 导入Tensorflow.js库
import * as tf from '@tensorflow/tfjs';
// 创建预测值和真实标签张量
const predictions = tf.tensor2d([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]);
const labels = tf.tensor2d([[1, 0], [0, 1], [1, 0], [0, 1]]);
// 计算准确率
const acc = tf.metrics.categoricalAccuracy(predictions, labels);
// 打印准确率
acc.print();
上述代码中,预测值张量有4个样本,每个样本有2个输出,表示2个类别的概率值。真实标签也有4个样本,每个样本也有2个输出,表示样本的类别标签。
函数的执行结果是一个标量张量,打印出来的结果是所有样本的平均准确率。
在使用tf.metrics.categoricalAccuracy()
函数时,需要注意以下几点:
预测值和真实标签的形状必须一致。
预测值和真实标签的取值必须都是0或1。
函数的输入张量可以是CPU或GPU张量,但是函数返回的标量张量是CPU张量。