📜  Scikit学习-K最近邻居(KNN)(1)

📅  最后修改于: 2023-12-03 15:05:05.400000             🧑  作者: Mango

Scikit学习-K最近邻居(KNN)

介绍

KNN是一种非常简单的机器学习算法,它可以用于分类和回归问题。在分类问题中,KNN算法会将未知数据点分类为和它最接近的已知数据点所属的类别。在回归问题中,KNN算法会将未知数据点的值设置为和它最接近的已知数据点的平均值。KNN算法没有训练过程,因为它会直接用已知数据点来做分类或回归。

使用Scikit实现KNN算法

我们可以使用Scikit的neighbors模块来实现KNN算法。下面是一个简单的例子:

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 将数据集分成训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建KNN分类器,将k设为3
knn = KNeighborsClassifier(n_neighbors=3)

# 在训练集上拟合模型
knn.fit(X_train, y_train)

# 用测试集数据进行预测
y_pred = knn.predict(X_test)

# 输出预测的准确率
print("Accuracy:", knn.score(X_test, y_test))

在这个示例中,我们首先加载Iris花数据集,并将其分为训练集和测试集。然后,我们创建了一个KNN分类器对象,并将k值设为3,因为我们希望找到3个最接近的训练数据点来做预测。接下来,我们在训练集上拟合模型,然后用测试集数据进行预测。最后,我们输出模型的准确率。在这个示例中,我们得到的准确率为97.78%。

调整KNN算法的参数

KNN算法有一个参数k,它决定了我们要找多少个最接近的训练数据点来做预测。通常情况下,k值越小,模型就越复杂。当k等于1时,模型最复杂,因为每个预测值都会直接使用最接近的训练数据点的标签。当k增大时,模型变得更简单,因为容忍更多的错误。当k等于训练集中的样本数时,模型最简单,因为它会始终输出训练集中最常见的标签。

我们可以使用交叉验证来调整KNN算法的参数。下面是一个简单的示例:

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 创建KNN分类器对象
knn = KNeighborsClassifier()

# 使用5折交叉验证来调整k的值
k_range = range(1, 31)
k_scores = []
for k in k_range:
    knn.n_neighbors = k
    scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')
    k_scores.append(scores.mean())

# 输出交叉验证得分
print(k_scores)

在这个示例中,我们首先加载Iris花数据集。然后,我们创建了一个KNN分类器对象,它的k值默认为5。我们用了k_range这个list来存储k的值,这里我们选择了1到30这30个值。接着,我们用交叉验证来计算每个k值的得分,将得分存储在k_scores这个list中。最后,我们输出这个list,以便比较不同k值得分的大小。

结论

KNN算法是一个非常简单的机器学习算法,但通常表现出色。Scikit的neighbors模块提供了很多功能来实现KNN算法,并且非常易于使用。通过调整k的值来优化模型是一种常用的方法,可以通过交叉验证计算得分来进行优化。