23 手写数字识别

在这一章节中,我们将使用 Scikit-Learn 框架进行手写数字识别的案例分析。手写数字识别是一个经典的机器学习问题,通常用于测试和展示分类算法的效果。我们将通过使用 MNIST 数据集来实现这一任务,MNIST 是一个包含 70000 张手写数字图片的数据集,广泛用于机器学习标准基准测试。

数据集概述

MNIST 数据集包含 0 到 9 的手写数字,每张图片为 28x28 像素。我们将使用 sklearn 中的 datasets 模块来加载这个数据集。

数据的加载

首先,我们需要安装 scikit-learnmatplotlib,后者将用于可视化数据。

1
pip install scikit-learn matplotlib

接下来,我们可以通过以下代码加载和查看数据:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib.pyplot as plt
from sklearn import datasets

# 加载手写数字数据集
digits = datasets.load_digits()

# 查看数据集的基本信息
print("数据集的大小:", digits.data.shape)
print("类别标签:", digits.target)

# 可视化第一张手写数字
plt.imshow(digits.images[0], cmap='gray')
plt.title(f"手写数字: {digits.target[0]}")
plt.axis('off')
plt.show()

在上述代码中,我们首先加载手写数字数据集,并输出数据集的大小和类别标签。然后,我们展示了数据集中第一个手写数字的图像。

数据预处理

在进行分类任务之前,我们需要对数据进行一些预处理。在这里,我们将手写数字的每个图片展开为一个 64 维的特征向量,同时进行训练集和测试集的划分,以09:01的比例进行。

1
2
3
4
5
6
7
8
from sklearn.model_selection import train_test_split

# 将图像数据展开为特征向量
X = digits.data
y = digits.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

选择分类算法

在这个案例中,我们将使用 KNeighborsClassifier(K近邻分类器)作为我们的分类算法。K近邻算法是一种简单而直观的分类算法,适合用于手写数字等图像分类任务。

模型训练

接下来,我们创建 K近邻分类器并在训练集上进行训练:

1
2
3
4
5
6
7
from sklearn.neighbors import KNeighborsClassifier

# 创建K近邻分类器实例
knn = KNeighborsClassifier(n_neighbors=3)

# 使用训练集进行训练
knn.fit(X_train, y_train)

模型评估

模型训练完成后,我们需要评估其在测试集上的性能。我们将计算模型的准确率。

1
2
3
4
5
6
7
8
from sklearn.metrics import accuracy_score

# 在测试集上进行预测
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("模型准确率: {:.2f}%".format(accuracy * 100))

通过这段代码,我们可以得到模型的准确率,帮助我们验证模型的有效性。

可视化分类结果

为了更好地理解模型的性能,我们可以可视化一些预测结果。我们将绘制真实标签和模型预测标签:

1
2
3
4
5
6
7
8
# 可视化部分测试结果
plt.figure(figsize=(12, 6))
for index in range(10):
plt.subplot(2, 10, index + 1)
plt.imshow(X_test[index].reshape(8, 8), cmap='gray')
plt.title(f'真实: {y_test[index]}\n预测: {y_pred[index]}')
plt.axis('off')
plt.show()

这里的代码将绘制第一排为真实标签的手写数字,第二排为模型的预测结果。在真实预测与模型输出之间进行比较,能够直观地看到模型的性能。

总结与展望

在这一节中,我们使用 Scikit-Learn 实现了一个简单的手写数字识别模型。我们从数据加载、预处理,到模型训练和评估,每一步都进行了详细的说明。通过 K近邻算法,我们取得了令人满意的准确率。

接下来,我们将在后续的章节中探索更复杂的机器学习任务,比如客户分群。这将带来不同的挑战,帮助我们深入理解机器学习的多样性和应用。

在手写数字识别中,我们还可以考虑使用更复杂的模型,比如随机森林、支持向量机 (SVM) 或神经网络,以进一步提高性能。

通过这些探索,我们将不断加深对 Scikit-Learn 框架的理解和应用能力。

作者

IT教程网(郭震)

发布于

2024-08-15

更新于

2024-08-16

许可协议

分享转发

学习下节

复习上节

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论