📜  KNN算法-查找最近的邻居(1)

📅  最后修改于: 2023-12-03 15:32:28.978000             🧑  作者: Mango

KNN算法-查找最近的邻居

KNN(K-Nearest Neighbors)是一种简单但有效的监督学习算法,用于分类和回归问题。KNN算法通过判断新数据点距离已知数据点的距离来确定该数据点属于哪个分类。在分类问题中,KNN算法根据最邻近的K个数据点的超平面位置进行分类;在回归问题中,KNN算法返回最邻近的K个数据点的平均值。本文将介绍如何实现KNN算法来查找最近的邻居。

实现KNN算法

下面是一个简单的Python程序来实现KNN算法。假设我们有一个数据集,其中包含n个样本和m个特征,并且我们希望在KNN中使用欧氏距离度量。

import numpy as np

class KNN:
    def __init__(self, k=5):
        self.k = k
        
    def fit(self, X_train, y_train):
        self.X_train = X_train
        self.y_train = y_train
        
    def predict(self, X_test):
        y_pred = []
        for x_test in X_test:
            dists = np.sum((self.X_train - x_test)**2, axis=1)
            nearest_idx = np.argsort(dists)[:self.k]
            nearest_labels = self.y_train[nearest_idx]
            y_pred.append(np.bincount(nearest_labels).argmax())
        return np.array(y_pred)
使用KNN算法

我们可以使用KNN算法来预测一个新样本的分类。假设我们有一个训练集和一个测试集,其中每个样本有两个特征。下面是一段简单的代码来使用我们的KNN类:

# 生成数据集
X_train = np.array([[1,2], [2,1], [2,3], [3,2], [1,3], [2,2], [3,3], [1,1]])
y_train = np.array([0, 0, 0, 0, 1, 1, 1, 1])
X_test = np.array([[2.5, 2]])

# 调用KNN类
knn = KNN(k=3)
knn.fit(X_train, y_train)
y_test = knn.predict(X_test)

print(y_test)

上述代码将输出预测的分类标签。

结论

KNN算法是一个简单但强大的算法,适用于分类和回归问题。使用我们在本文中介绍的Python代码,可以轻松实现KNN算法。