详解 Scikit-learn 的 datasets.load_digits函数:加载手写数字数据集

  • Post category:Python

sklearn.datasets.load_digits 函数是 scikit-learn 中用于加载手写数字数据集的函数。该数据集包含 1797 张 8×8 的手写数字图片,是一个经典的机器学习数据集,广泛用于测试分类算法的性能。

该函数返回一个代表手写数字数据集的 Python Bunch 类型对象,包含以下属性:

  • data: 一个 $n \times m$ 的 numpy 数组,表示数据集的特征矩阵。其中 $n$ 是数据集样本数量,$m$ 是每个样本的特征数(即图像的像素数)。
  • target: 一个长度为 $n$ 的 numpy 数组,表示每个样本的真实标签(即图像中手写数字的真实值)。
  • target_names: 一个长度为 $k$ 的 numpy 数组,表示可能的标签值。注意,手写数字数据集中可能的标签值是 0 到 9,因此 target_names 是一个长度为 10 的数组。
  • images: 一个 $n \times h \times w$ 的 numpy 数组,表示原始图像。其中 $n$ 是样本数量,$h$ 和 $w$ 是图像的高度和宽度,即8×8。

下面展示该数据集的加载方法和数据集探索的实例:

from sklearn.datasets import load_digits
import matplotlib.pyplot as plt

# 加载手写数字数据集
digits = load_digits()

# 打印数据集的描述信息
print(digits.DESCR)

# 打印数据集的特征和标签
print(digits.data)
print(digits.target)

# 打印数据集的维度信息
print(digits.data.shape)
print(digits.target.shape)

# 打印数据集的图片
n_samples = 5
fig, axes = plt.subplots(n_samples, 5, figsize=(10, 10), 
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(n_samples):
    for j in range(5):
        axes[i, j].imshow(digits.images[i * 5 + j], cmap='gray_r')
        axes[i, j].set_title(digits.target[i * 5 + j])

plt.show()

上述代码中,首先加载手写数字数据集,并打印出数据集的描述信息,特征矩阵和标签。其中特征矩阵表示每张图像的像素值,标签表示每张图像中手写数字的真实值。然后,打印出数据集的维度信息,以便了解数据集的规模和样本个数。最后,用 matplotlib 库绘制出前 5 个样本的图像,以便观察手写数字数据集的样本情况。

另外,下面还展示了一个使用手写数字数据集进行分类的实例:

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# 加载手写数字数据集
digits = load_digits()

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(digits.data, 
                                                    digits.target, 
                                                    test_size=0.3, 
                                                    random_state=0)

# 训练逻辑回归模型
lr = LogisticRegression(solver='liblinear', multi_class='auto')
lr.fit(X_train, y_train)

# 在测试集上进行预测,并计算准确率
y_pred = lr.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:', accuracy)

上述代码中,首先加载手写数字数据集,并使用 train_test_split 函数将数据集划分为训练集和测试集,其中测试集占数据集的 30%。然后,使用逻辑回归算法训练分类器,并在测试集上进行预测,计算分类器的准确率。该代码实现了一种简单的手写数字分类器,并展示了 scikit-learn 中加载手写数字数据集的基本方法。