📅  最后修改于: 2023-12-03 15:04:10.939000             🧑  作者: Mango
在机器学习中,混淆矩阵是一种用于评估分类模型性能的常用工具。TensorFlow提供了tf.math.confusion_matrix()函数,用于计算分类模型的混淆矩阵矩阵。本文将对该函数进行介绍。
tf.math.confusion_matrix(
labels, predictions, num_classes=None, weights=None, dtype=tf.dtypes.int32,
name=None
)
混淆矩阵张量,形状为[num_classes, num_classes]。矩阵的行表示实际类别,列表示预测类别。
我们来看一个例子:
import tensorflow as tf
import numpy as np
labels = np.array([0, 1, 2, 3, 4])
predictions = np.array([3, 3, 1, 1, 4])
confusion_mat = tf.math.confusion_matrix(labels, predictions, num_classes=5)
print(confusion_mat)
输出结果为:
tf.Tensor(
[[0 0 0 1 0]
[0 1 0 1 0]
[0 0 0 0 0]
[0 0 0 1 0]
[0 0 0 0 1]], shape=(5, 5), dtype=int32)
我们可以通过以下方式解读矩阵:
在分类任务中,混淆矩阵是评估模型性能的重要工具。使用TensorFlow中的tf.math.confusion_matrix()函数可以方便地计算混淆矩阵。