📜  svm 分类器 sklearn (1)

📅  最后修改于: 2023-12-03 15:35:12.201000             🧑  作者: Mango

SVM分类器(sklearn)

简介

SVM全称支持向量机(Support Vector Machine),是一种二分类模型。其基本思想是将实例映射到高维空间中,然后在高维空间中找到一个最优超平面,使得不同类别的实例能够被分开。SVM分类器常用于高维数据分类和回归分析。

sklearn是Python中用于数据挖掘和数据分析的开源模块,提供了各种机器学习算法的实现。

在sklearn中,SVM分类器实现在svm模块中。

安装

如果你还没有安装sklearn模块,可以通过以下命令进行安装:

!pip install -U scikit-learn
使用
加载数据集

我们先加载一个样例数据集,iris花卉数据集:

from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target
分割数据集

将数据分为训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
训练模型

使用svm.SVC类实例化并训练模型:

from sklearn import svm

clf = svm.SVC(kernel='linear', C=1, gamma='scale')
clf.fit(X_train, y_train)
预测

使用训练好的分类器预测测试集:

y_pred = clf.predict(X_test)
评估模型

使用以下指标评估模型的性能:

from sklearn.metrics import accuracy_score, confusion_matrix

print("Accuracy:", accuracy_score(y_test, y_pred))
print("Confusion Matrix:", confusion_matrix(y_test, y_pred))

以上指标分别是准确率(Accuracy)和混淆矩阵(Confusion Matrix)。

参数解释

SVM分类器常用的参数有以下几个:

  • kernel:核函数,用于将原始空间映射到高维特征空间,常用的有线性核(linear)、多项式核(poly)、径向基核(rbf)等。
  • C:误差项的惩罚系数,用于控制分类边界的平滑程度。
  • gamma:核函数的系数,一般只对多项式核函数和径向基核函数有效。
总结

SVM分类器是一种常用的机器学习算法,可以处理高维数据的分类和回归问题。而sklearn则是Python中实现机器学习算法的重要模块之一,在sklearn中可以轻松地实现SVM分类器。