11 深度学习框架之Keras
在上一篇文章中,我们介绍了深度学习框架之一的TensorFlow。在本篇教程中,我们将聚焦于Keras,一个基于Python的深度学习框架,它可以与TensorFlow无缝集成并提供更高级别的API,适合快速构建和训练深度学习模型。接下来,我们将通过案例和代码,探索Keras的基本用法。
Keras简介
Keras是一个高层深度学习API,旨在使构建神经网络变得简单、易于扩展并具可维护性。Keras支持多种后端,如TensorFlow、Theano和CNTK,但自从TensorFlow 2.0发布以来,Keras已成为TensorFlow的一部分,通常通过tf.keras
来使用。
Keras的优点
- 简易性:Keras提供直观的API,可以通过更少的代码实现复杂的神经网络模型。
- 模块化:你可以自由组合网络层、损失函数、优化器等,支持更加灵活的模型设计。
- 社区支持:作为一个广泛使用的库,Keras有强大的用户社区和丰富的文档资源。
Keras基本组件
在使用Keras构建深度学习模型时,你需要了解以下基本组件:
- 模型:Keras提供了两种主要的模型类型:
Sequential
模型和Functional
模型。 - 层:模型由多个层(
Layers
)组成,每一层执行特定的功能,例如卷积、池化、全连接等。 - 损失函数:用于评估模型的输出与真实值之间的差距。
- 优化器:用于更新模型参数以减少损失函数的值。
- 评估指标:用于监测模型性能的指标。
Keras使用示例
下面我们将通过一个简单的例子来演示如何使用Keras构建一个分类模型。假设我们要使用手写数字数据集MNIST来训练一个神经网络
用于数字分类。
1. 数据准备
首先,我们需要加载MNIST数据集,并对数据进行预处理。
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical
# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据预处理
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# 改变数据形状
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
# 将标签进行独热编码
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
2. 构建模型
我们使用Sequential
模型来构建一个简单的全连接神经网络。
from keras.models import Sequential
from keras.layers import Flatten, Dense
# 创建Sequential模型
model = Sequential()
model.add(Flatten(input_shape=(28, 28, 1))) # 输入层
model.add(Dense(128, activation='relu')) # 隐藏层
model.add(Dense(10, activation='softmax')) # 输出层
3. 编译模型
在训练之前,我们需要编译模型,指定损失函数、优化器和评估指标。
model.compile(loss='categorical_crossentropy', # 使用类别交叉熵作为损失函数
optimizer='adam', # 使用Adam优化器
metrics=['accuracy']) # 监测准确率
4. 训练模型
接下来,我们使用fit
方法来训练模型。
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)
5. 评估模型
模型训练完成后,可以在测试数据集上评估模型的性能。
loss, accuracy = model.evaluate(x_test, y_test)
print(f'Test loss: {loss:.4f}, Test accuracy: {accuracy:.4f}')
小结
在本教程中,我们介绍了Keras的基本概念及其组件,并通过MNIST手写数字识别的示例详细展示了如何构建、编译、训练及评估一个简单的神经网络模型。Keras的简易性和灵活性使得深度学习变得更加高效和便捷。
接下来,我们将在下一篇文章中讨论另一个流行的深度学习框架——PyTorch。请保持关注!