详解 Scikit-learn 的 metrics.confusion_matrix函数:计算混淆矩阵

  • Post category:Python

Scikit-learn是一个Python机器学习库,其中包含许多评估分类性能指标的功能。其中一个重要的函数就是sklearn.metrics.confusion_matrix函数。该函数可以计算混淆矩阵,用于评估分类模型的准确性。

作用

sklearn.metrics.confusion_matrix函数的作用是计算分类模型的混淆矩阵。混淆矩阵是一种评估分类模型预测效果的矩阵,它显示了分类模型预测结果与实际标签之间的关系。

混淆矩阵包含四个元素:真正例(TP)、假正例(FP)、真反例(TN)和假反例(FN),它们表示预测结果和实际标签之间的四种情况。

混淆矩阵计算公式为:

[[TN  FP]
 [FN  TP]]

使用方法

sklearn.metrics.confusion_matrix函数的参数如下:

confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, normalize=None)

其中:

  • y_true:array-like,真实标签
  • y_pred:array-like,预测结果
  • labels:array-like,标签列表
  • sample_weight:array-like,每个样本的权重
  • normalize:str,是否对混淆矩阵进行归一化计算

返回:

  • C:array,混淆矩阵

使用示例1:

from sklearn.metrics import confusion_matrix

y_true = [1, 0, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 1]

confusion_matrix(y_true, y_pred)

输出:

array([[2, 0],
       [1, 3]])

在这个例子中,真实标签是1、0、1、1、0、1,预测结果是0、0、1、1、0、1。混淆矩阵显示了预测结果与实际标签之间的四种情况,其中有2个真负例、3个真正例、1个假反例和0个假正例。

使用示例2:

from sklearn.metrics import confusion_matrix

y_true = [0, 1, 0, 1, 0, 1]
y_pred = [1, 1, 1, 0, 0, 1]

labels = [0, 1]

confusion_matrix(y_true, y_pred, labels=labels)

输出:

array([[1, 2],
       [1, 2]])

在这个例子中,真实标签是0、1、0、1、0、1,预测结果是1、1、1、0、0、1。标签列表是[0,1],混淆矩阵显示了预测结果与实际标签之间的四种情况,其中有1个真负例、2个真正例、1个假反例和2个假正例。

总结

sklearn.metrics.confusion_matrix函数是用于计算分类模型混淆矩阵的函数,用于评估分类模型预测效果。在实际使用时,需要输入真实标签和预测结果,并可以选择标签列表及计算归一化的方式。混淆矩阵可以帮助我们通过对预测结果和真实标签的四种情况进行评估和比较,来评估分类模型的准确性。