详解 Scikit-learn 的 manifold.TSNE函数:t-SNE 数据降维

  • Post category:Python

Scikit-learn是一个重要的Python机器学习库,其中的manifold.TSNE函数使用了t-SNE技术(一种非线性降维技术)来将高维数据可视化为低维数据。实际上,t-SNE可用于生成可视化图表,其中高维数据点被表示为低维空间中的点,距离与相似性信息一致。

下面提供一个使用例子:

首先,导入库和数据集:

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

iris = sns.load_dataset("iris")
X = iris.drop("species", axis=1)
y = iris["species"]

接下来,我们可以使用TSNE函数来降维可视化:

tsne = TSNE(n_components=2, random_state=0)
X_tsne = tsne.fit_transform(X)
df = pd.DataFrame({'x':X_tsne[:,0], 'y':X_tsne[:,1] ,'label':y})
sns.scatterplot(data=df, x="x", y="y", hue="label", palette="deep")
plt.show()

在上述代码中,我们使用 n_components 参数来设置降维后生成的数据集的列数。这里我们将其设置为2,以在二维平面上可视化数据集。在后面的代码中,我们使用 Seaborn 库的 scatterplot() 函数来画出生成的散点图。

下面是另一个使用 TSNE 进行数据降维和可视化的实际例子:

from sklearn.datasets import load_digits

digits = load_digits()
X, y = digits.data, digits.target

tsne = TSNE(n_components=2, random_state=0)
X_tsne = tsne.fit_transform(X)

plt.figure()
plt.scatter(X_tsne[:,0], X_tsne[:,1], c=y, cmap='rainbow')
plt.colorbar()
plt.show()

在这个例子里,我们使用 digits 数据集,其中包含了手写数字的图像。使用 TSNE 函数,我们将这些图像经过降维转换为二维平面的点集。每个点的颜色代表了其对应的数字类别。

这里,我们使用 matplotlib 的 scatter() 函数来画出散点图,并使用 colorbar() 函数来画出数字与颜色之间的对应关系。

总之,Scikit-learn的manifold.TSNE函数是一个非常强大的数据降维技术,能够将高维数据转化为低维数据,并生成可视化的表现形式来更好地展示数据的特征和内在结构。