sklearn.datasets.load_digits
函数是 scikit-learn 中用于加载手写数字数据集的函数。该数据集包含 1797 张 8×8 的手写数字图片,是一个经典的机器学习数据集,广泛用于测试分类算法的性能。
该函数返回一个代表手写数字数据集的 Python Bunch
类型对象,包含以下属性:
data
: 一个 $n \times m$ 的 numpy 数组,表示数据集的特征矩阵。其中 $n$ 是数据集样本数量,$m$ 是每个样本的特征数(即图像的像素数)。target
: 一个长度为 $n$ 的 numpy 数组,表示每个样本的真实标签(即图像中手写数字的真实值)。target_names
: 一个长度为 $k$ 的 numpy 数组,表示可能的标签值。注意,手写数字数据集中可能的标签值是 0 到 9,因此target_names
是一个长度为 10 的数组。images
: 一个 $n \times h \times w$ 的 numpy 数组,表示原始图像。其中 $n$ 是样本数量,$h$ 和 $w$ 是图像的高度和宽度,即8×8。
下面展示该数据集的加载方法和数据集探索的实例:
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
# 加载手写数字数据集
digits = load_digits()
# 打印数据集的描述信息
print(digits.DESCR)
# 打印数据集的特征和标签
print(digits.data)
print(digits.target)
# 打印数据集的维度信息
print(digits.data.shape)
print(digits.target.shape)
# 打印数据集的图片
n_samples = 5
fig, axes = plt.subplots(n_samples, 5, figsize=(10, 10),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(n_samples):
for j in range(5):
axes[i, j].imshow(digits.images[i * 5 + j], cmap='gray_r')
axes[i, j].set_title(digits.target[i * 5 + j])
plt.show()
上述代码中,首先加载手写数字数据集,并打印出数据集的描述信息,特征矩阵和标签。其中特征矩阵表示每张图像的像素值,标签表示每张图像中手写数字的真实值。然后,打印出数据集的维度信息,以便了解数据集的规模和样本个数。最后,用 matplotlib 库绘制出前 5 个样本的图像,以便观察手写数字数据集的样本情况。
另外,下面还展示了一个使用手写数字数据集进行分类的实例:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# 加载手写数字数据集
digits = load_digits()
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(digits.data,
digits.target,
test_size=0.3,
random_state=0)
# 训练逻辑回归模型
lr = LogisticRegression(solver='liblinear', multi_class='auto')
lr.fit(X_train, y_train)
# 在测试集上进行预测,并计算准确率
y_pred = lr.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:', accuracy)
上述代码中,首先加载手写数字数据集,并使用 train_test_split
函数将数据集划分为训练集和测试集,其中测试集占数据集的 30%。然后,使用逻辑回归算法训练分类器,并在测试集上进行预测,计算分类器的准确率。该代码实现了一种简单的手写数字分类器,并展示了 scikit-learn 中加载手写数字数据集的基本方法。