17 使用Keras构建简单模型

在上篇的内容中,我们了解了Keras是什么以及它的基本概念。当前篇将带你走进Keras的实际应用,具体是如何构建一个简单的模型。我们会通过一个经典的案例,即手写数字识别(MNIST数据集),来演示如何用Keras构建模型。

Keras基本组成部分

Keras是一个高层次的神经网络API,能够以简单和高效的方式构建和训练深度学习模型。构建模型主要有以下几个重要步骤:

  1. 定义模型:选择模型类型(如顺序模型或函数式模型)。
  2. 添加层:向模型中逐层添加神经网络层。
  3. 编译模型:指定损失函数、优化器和评估指标。
  4. 训练模型:通过训练数据拟合模型。
  5. 评估与预测:使用测试数据评估模型性能,进行预测。

在这一过程中,我们主要使用Sequential模型,它是Keras提供的最简单形式,适合于逐层叠加的神经网络。

构建手写数字识别模型

步骤1:导入必要的库

首先,确保你已经安装了TensorFlow和Keras。接下来,导入我们需要的库:

1
2
3
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist

步骤2:加载和预处理数据

Keras提供了MNIST数据集的方便方法,我们可以直接加载,并进行预处理。首先加载数据:

1
2
3
4
5
6
# 加载数据
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 数据规范化处理
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

在这里,我们将图像数据的像素值从0-255缩放到0-1之间,以帮助模型更快收敛。

步骤3:构建模型

接下来,我们将构建一个基本的神经网络模型。我们将使用一个包含两个隐藏层的顺序模型:

1
2
3
4
model = models.Sequential()
model.add(layers.Flatten(input_shape=(28, 28))) # 将28x28的图像展平为784的向量
model.add(layers.Dense(128, activation='relu')) # 第一个全连接层,使用ReLU激活函数
model.add(layers.Dense(10, activation='softmax')) # 输出层,10个类别,使用Softmax激活函数

步骤4:编译模型

编译模型时,我们需要选择损失函数、优化器和评估指标。对于分类问题,我们通常使用Categorical Crossentropy损失函数,并选择Adam优化器:

1
2
3
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

步骤5:训练模型

现在,我们可以使用训练数据来训练模型。使用fit方法,我们可以指定训练周期数(epochs)和每个周期的批量大小(batch size):

1
model.fit(train_images, train_labels, epochs=5, batch_size=32)

在这里,我们设置训练5个周期,批量大小为32。

模型评估与预测

训练完成后,我们可以通过测试数据集评估模型性能:

1
2
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'Test accuracy: {test_acc}')

评估完成后,如果我们想要对新的数据进行预测,可以使用predict方法:

1
2
predictions = model.predict(test_images)
predicted_class = tf.argmax(predictions, axis=1)

这里,predictions包含对每个测试图片的类别概率,我们可以使用argmax来找到概率最大的类别。

总结

通过上述简单的步骤,我们已经成功构建了一个使用Keras的手写数字识别模型。从数据加载到模型训练,我们看到了Keras构建深度学习模型的基本流程。这只是一个开始,在下一篇中,我们将深入学习如何训练模型的基本步骤,包括如何进行模型调优和超参数调整。

请继续关注后续内容,深入理解如何训练和优化你的模型。

17 使用Keras构建简单模型

https://zglg.work/tensorflow-zero/17/

作者

IT教程网(郭震)

发布于

2024-08-10

更新于

2024-08-11

许可协议

分享转发

复习上节

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论