详解 Scikit-learn 的 neighbors.KNeighborsClassifier函数:K 近邻分类器

  • Post category:Python

sklearn.neighbors.KNeighborsClassifier 是 Scikit-learn 中一种针对 K-近邻算法的分类器,它支持二分类和多分类问题,并且具有一定的适应能力。在机器学习领域中,K-近邻算法被广泛应用于数据分类任务中,可以用于解决分类、回归、密度估计、离群点检测等各种问题。

使用 KNeighborsClassifier 进行分类时,需要提供样本特征矩阵 X 和类别标签 y 两个参数。其中,X 是一个二维数组,每行代表一个样本的特征向量;y 是一个一维数组,每个元素代表一个样本对应的类别标签。

KNeighborsClassifier 的主要参数包括:n_neighbors、weights 和 algorithm。

  • n_neighbors 表示选择的邻居数。默认情况下,它的取值为 5。
  • weights 可以是 uniformdistance,默认是 uniform。如果是 uniform,则所有邻居的权重都相同;如果是 distance,则权重是距离的倒数(即越近的样本权重越大)。
  • algorithm 可以是 ball_treekd_treebrute。默认情况下使用 auto 选项。简单来说,ball_treekd_tree 可以更快地计算出离某个点最近的 k 个点;brute 则适用于小数据集。如果要使用 ball_treekd_tree,则必须满足数据集的维度小于 20。

下面是一个简单的代码示例,演示了 KNeighborsClassifier 的基本用法:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

# 加载iris数据集
iris = load_iris()
X, y = iris.data, iris.target

# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 构建 KNN 分类器
knn = KNeighborsClassifier()

# 拟合数据
knn.fit(X_train, y_train)

# 预测测试集的类别
y_predict = knn.predict(X_test)

# 输出预测结果
print(y_predict)

这个示例代码加载了鸢尾花数据集,并将其随机分成了训练集和测试集,然后使用 KNeighborsClassifier 构建了一个 KNN 分类器,拟合了训练数据。

另一个示例代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

# 创建一个二分类数据集
X, y = make_classification(n_samples=100, n_features=2, n_redundant=0, n_clusters_per_class=1, random_state=4)

# 绘制样本散点图
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm')
plt.title('Binary Classification Dataset')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

# 构建 KNN 分类器
knn = KNeighborsClassifier(n_neighbors=3)

# 拟合数据
knn.fit(X, y)

# 绘制决策边界
xlim = plt.gca().get_xlim()
ylim = plt.gca().get_ylim()
xx, yy = np.meshgrid(np.linspace(xlim[0], xlim[1], 200), np.linspace(ylim[0], ylim[1], 200))
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.contourf(xx, yy, Z, alpha=0.4, cmap='coolwarm')
plt.xlim(xlim)
plt.ylim(ylim)
plt.show()

这个示例代码创建了一个二分类数据集,并使用 KNeighborsClassifier 构建了一个 KNN 分类器。通过绘制样本散点图和决策边界,可以直观地看出 KNN 分类器的分类效果。

以上两个示例代码可以帮助初学者更好地掌握 KNeighborsClassifier 的使用方法。