📜  Python – tensorflow.math.confusion_matrix()(1)

📅  最后修改于: 2023-12-03 15:04:10.939000             🧑  作者: Mango

Python – tensorflow.math.confusion_matrix()

在机器学习中,混淆矩阵是一种用于评估分类模型性能的常用工具。TensorFlow提供了tf.math.confusion_matrix()函数,用于计算分类模型的混淆矩阵矩阵。本文将对该函数进行介绍。

函数签名
tf.math.confusion_matrix(
    labels, predictions, num_classes=None, weights=None, dtype=tf.dtypes.int32,
    name=None
)
  • labels:实际数据标签。
  • predictions:预测标签。与labels必须有相同的shape。
  • num_classes:分类的数量。如果没有提供,则自动推断为labels和predictions中最大的标签数加1。
  • weights:权重张量。
  • dtype:输出数值类型。
  • name:操作名。
返回值

混淆矩阵张量,形状为[num_classes, num_classes]。矩阵的行表示实际类别,列表示预测类别。

例子

我们来看一个例子:

import tensorflow as tf
import numpy as np

labels = np.array([0, 1, 2, 3, 4])
predictions = np.array([3, 3, 1, 1, 4])

confusion_mat = tf.math.confusion_matrix(labels, predictions, num_classes=5)

print(confusion_mat)

输出结果为:

tf.Tensor(
[[0 0 0 1 0]
 [0 1 0 1 0]
 [0 0 0 0 0]
 [0 0 0 1 0]
 [0 0 0 0 1]], shape=(5, 5), dtype=int32)

我们可以通过以下方式解读矩阵:

  • 第1行表示实际标签为0的样本的预测结果:有0个样本被预测为0类别,0个样本被预测为1类别,0个样本被预测为2类别,1个样本被预测为3类别,0个样本被预测为4类别。
  • 第2行表示实际标签为1的样本的预测结果。
  • ...
  • 第1列表示预测结果为0的样本的实际结果。
  • 第2列表示预测结果为1的样本的实际结果。
  • ...
总结

在分类任务中,混淆矩阵是评估模型性能的重要工具。使用TensorFlow中的tf.math.confusion_matrix()函数可以方便地计算混淆矩阵。