📜  多类分类的一对一策略

📅  最后修改于: 2022-05-13 01:55:22.599000             🧑  作者: Mango

多类分类的一对一策略

先决条件:分类入门/

分类可能是最常见的机器学习任务。在我们深入了解One-vs-Rest (OVR)分类器是什么以及它们如何工作之前,您可以点击下面的链接并简要了解分类是什么以及它是如何有用的。

一般来说,有两种分类算法:

  1. 二进制分类算法。
  2. 多类分类算法。

二进制分类是我们必须将对象分为两组。通常,这两组由“真”和“假”组成。例如,给定一组特定的健康属性,二元分类任务可能是确定一个人是否患有糖尿病。

另一方面,在多类分类中,有两个以上的类。例如,给定水果的一组属性,如形状和颜色,多类分类任务将是确定水果的类型。

所以,既然您已经了解了二元和多类分类的工作原理,那么让我们继续了解如何使用 one-vs-rest 启发式方法。

一对一(OVR)方法
许多流行的分类算法都是为二进制分类问题而设计的。这些算法包括:

  • 逻辑回归
  • 支持向量机 (SVM)
  • 感知器模型

还有很多。

因此,这些流行的分类算法不能直接用于多类分类问题。一些启发式方法可以将多类分类问题分解为许多不同的二元分类问题。为了理解它是如何工作的,让我们考虑一个例子:比如说,一个分类问题是将各种水果分类为三种水果:香蕉、橙子或苹果。现在,这显然是一个多类分类问题。如果你想使用二进制分类算法,比如 SVM。 One-vs-Rest 方法处理此问题的方式如下所示:

由于分类问题中存在三个类,One-vs-Rest 方法会将这个问题分解为三个二元分类问题:

  • 问题 1:香蕉 vs [橙子,苹果]
  • 问题 2 : Orange vs [Banana, Apple]
  • 问题 3:苹果 vs [香蕉,橙子]

因此,不是将其解决为(香蕉 vs 橙子 vs 苹果),而是使用如上所示的三个二元分类问题来解决。

这种方法的一个主要缺点或缺点是必须创建许多模型。对于具有“n”个类的多类问题,必须创建“n”个模型,这可能会减慢整个过程。但是,它对于具有少量类的数据集非常有用,我们希望使用像 SVM 或 Logistic 回归这样的模型。

使用 Python3 实现 One-vs-Rest 方法
Python 的 scikit-learn 库提供了一个方法 OneVsRestClassifier(estimator, *, n_jobs=None) 来实现这个方法。对于此实施,我们将使用流行的“葡萄酒数据集”,使用化学属性确定葡萄酒的来源。我们可以使用 scikit-learn 来引导这个数据集。要了解有关此数据集的更多信息,您可以使用以下链接:Wine Dataset

我们将使用支持向量机,这是一种二元分类算法,并将其与 One-vs-Rest 启发式算法一起使用来执行多类分类。

为了评估我们的模型,我们将看到测试集的准确度得分和模型的分类报告。

from sklearn.datasets import load_wine
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
import warnings
   
''' 
We are ignoring warnings because of a peculiar fact about this
dataset. The 3rd label, 'Label2' is never predicted and so the python 
interpreter throws a warning. However, this can safely be ignored because 
we are not concerned if a certain label is predicted or not 
'''
warnings.filterwarnings('ignore')
   
# Loading the dataset
dataset = load_wine()
X = dataset.data
y = dataset.target
   
# Splitting the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size = 0.1, random_state = 13)
   
# Creating the SVM model
model = OneVsRestClassifier(SVC())
   
# Fitting the model with training data
model.fit(X_train, y_train)
   
# Making a prediction on the test set
prediction = model.predict(X_test)
   
# Evaluating the model
print(f"Test Set Accuracy : {accuracy_score(
    y_test, prediction) * 100} %\n\n")
print(f"Classification Report : \n\n{classification_report(
    y_test, prediction)}")

输出:

Test Set Accuracy : 66.66666666666666 %

Classification Report : 

              precision    recall  f1-score   support

           0       0.62      1.00      0.77         5
           1       0.70      0.88      0.78         8

   micro avg       0.67      0.92      0.77        13
   macro avg       0.66      0.94      0.77        13
weighted avg       0.67      0.92      0.77        13

我们得到大约 66.667% 的测试集准确率。这对这个数据集来说还不错。该数据集因难以分类而臭名昭著,基准准确度为 62.4 +- 0.4 %。所以,我们的结果实际上是相当不错的。

结论:
现在您已经知道如何使用 One-vs-Rest 启发式方法使用二元分类器执行多类分类,您可以在下次必须执行一些多类分类任务时尝试使用它。