23 手写数字识别
在这一章节中,我们将使用 Scikit-Learn 框架进行手写数字识别的案例分析。手写数字识别是一个经典的机器学习问题,通常用于测试和展示分类算法的效果。我们将通过使用 MNIST 数据集来实现这一任务,MNIST 是一个包含 70000 张手写数字图片的数据集,广泛用于机器学习标准基准测试。
数据集概述
MNIST 数据集包含 0 到 9 的手写数字,每张图片为 28x28 像素。我们将使用 sklearn
中的 datasets
模块来加载这个数据集。
数据的加载
首先,我们需要安装 scikit-learn
和 matplotlib
,后者将用于可视化数据。
pip install scikit-learn matplotlib
接下来,我们可以通过以下代码加载和查看数据:
import matplotlib.pyplot as plt
from sklearn import datasets
# 加载手写数字数据集
digits = datasets.load_digits()
# 查看数据集的基本信息
print("数据集的大小:", digits.data.shape)
print("类别标签:", digits.target)
# 可视化第一张手写数字
plt.imshow(digits.images[0], cmap='gray')
plt.title(f"手写数字: {digits.target[0]}")
plt.axis('off')
plt.show()
在上述代码中,我们首先加载手写数字数据集,并输出数据集的大小和类别标签。然后,我们展示了数据集中第一个手写数字的图像。
数据预处理
在进行分类任务之前,我们需要对数据进行一些预处理。在这里,我们将手写数字的每个图片展开为一个 64 维的特征向量,同时进行训练集和测试集的划分,以09:01的比例进行。
from sklearn.model_selection import train_test_split
# 将图像数据展开为特征向量
X = digits.data
y = digits.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
选择分类算法
在这个案例中,我们将使用 KNeighborsClassifier
(K近邻分类器)作为我们的分类算法。K近邻算法是一种简单而直观的分类算法,适合用于手写数字等图像分类任务。
模型训练
接下来,我们创建 K近邻分类器并在训练集上进行训练:
from sklearn.neighbors import KNeighborsClassifier
# 创建K近邻分类器实例
knn = KNeighborsClassifier(n_neighbors=3)
# 使用训练集进行训练
knn.fit(X_train, y_train)
模型评估
模型训练完成后,我们需要评估其在测试集上的性能。我们将计算模型的准确率。
from sklearn.metrics import accuracy_score
# 在测试集上进行预测
y_pred = knn.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("模型准确率: {:.2f}%".format(accuracy * 100))
通过这段代码,我们可以得到模型的准确率,帮助我们验证模型的有效性。
可视化分类结果
为了更好地理解模型的性能,我们可以可视化一些预测结果。我们将绘制真实标签和模型预测标签:
# 可视化部分测试结果
plt.figure(figsize=(12, 6))
for index in range(10):
plt.subplot(2, 10, index + 1)
plt.imshow(X_test[index].reshape(8, 8), cmap='gray')
plt.title(f'真实: {y_test[index]}\n预测: {y_pred[index]}')
plt.axis('off')
plt.show()
这里的代码将绘制第一排为真实标签的手写数字,第二排为模型的预测结果。在真实预测与模型输出之间进行比较,能够直观地看到模型的性能。
总结与展望
在这一节中,我们使用 Scikit-Learn 实现了一个简单的手写数字识别模型。我们从数据加载、预处理,到模型训练和评估,每一步都进行了详细的说明。通过 K近邻算法,我们取得了令人满意的准确率。
接下来,我们将在后续的章节中探索更复杂的机器学习任务,比如客户分群。这将带来不同的挑战,帮助我们深入理解机器学习的多样性和应用。
在手写数字识别中,我们还可以考虑使用更复杂的模型,比如随机森林、支持向量机 (SVM) 或神经网络,以进一步提高性能。
通过这些探索,我们将不断加深对 Scikit-Learn 框架的理解和应用能力。