Scikit-learn 是一款流行的机器学习库,提供了多种聚类算法,其中之一就是 MeanShift。
MeanShift 算法介绍
MeanShift 是一种基于密度的聚类算法,可以根据样本点的密度分布,自动确定聚类中心的个数,且支持非球形簇的聚类。它的实现方式是,从一个样本点随机出发,计算样本点在以自身为圆心半径为 bandwidth 的球形区域内,所有点的均值位置,并将当前样本点移动到该位置。然后重复以上过程直到收敛。
sklearn.cluster.MeanShift 函数使用方法
sklearn.cluster.MeanShift 函数实现了 MeanShift 聚类算法,函数的调用方式如下:
from sklearn.cluster import MeanShift
clt = MeanShift(bandwidth=2)
clt.fit(data)
其中,bandwidth 表示球形区域的半径,默认值为1。data 表示用于聚类的数据特征矩阵。
MeanShift 函数的返回值
sklearn.cluster.MeanShift 函数的 fit 方法会返回以下两个属性:
cluster_centers_
:聚类中心点的坐标labels_
:每个样本点的聚类标签
MeanShift 使用例子
例子一:使用 MeanShift 聚类鸢尾花数据集
我们使用 sklearn 自带的鸢尾花数据集,首先导入相关的包,以及数据集。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
然后我们使用 MeanShift 聚类算法对数据进行聚类。
clt = MeanShift()
clt.fit(X)
最后我们将聚类结果可视化。
colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'w']
for i in range(len(X)):
plt.scatter(X[i, 0], X[i, 1], color=colors[clt.labels_[i]])
plt.show()
运行后会出现一个散点图,其中不同颜色的点表示不同的聚类结果。
例子二:使用 MeanShift 聚类手写数字数据集
我们使用 sklearn 自带的手写数字数据集,首先导入相关的包,以及数据集。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift
from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
y = digits.target
然后我们使用 MeanShift 聚类算法对数据进行聚类。
clt = MeanShift()
clt.fit(X)
最后我们将聚类结果可视化。
labels = np.unique(clt.labels_)
for label in labels:
plt.figure()
plt.title("cluster %d" % label)
for i in range(len(X)):
if clt.labels_[i] == label:
plt.subplot(3, 5, i + 1)
plt.imshow(X[i].reshape(8, 8), cmap=plt.cm.gray_r)
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()
运行后会出现多个子图,每个子图表示一个类,其中包含了几个之前未分类的手写数字。