14 训练模型

在上一篇中,我们探讨了如何通过compile方法来为模型编译优化器、损失函数和指标。现在,我们将继续深入,学习如何训练模型,具体来说就是如何使用fit方法进行模型训练。

训练模型的准备

在开始训练模型之前,我们需要确保以下几项已经完成:

  1. 模型已创建并编译:您需要创建一个Keras模型,并使用compile方法定义优化器、损失函数和评价指标。
  2. 数据集已准备好:我们需要准备好输入数据 X 和目标输出数据 y

下面是一个简单的例子,展示了如何创建一个模型并进行编译。

1
2
3
4
5
6
7
8
9
10
11
12
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np

# 创建一个简单的神经网络模型
model = Sequential()
model.add(Dense(64, activation='relu', input_shape=(10,))) # 输入层
model.add(Dense(1, activation='sigmoid')) # 输出层

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

接下来,我们需要准备数据。假设我们有一组特征数据 X 和对应的标签 y

1
2
X = np.random.rand(1000, 10)  # 1000个样本,每个样本10个特征
y = np.random.randint(2, size=(1000, 1)) # 1000个样本的二元标签

使用fit方法训练模型

一旦我们有了数据和编译好的模型,我们就可以使用fit方法进行训练了。fit方法的基本语法如下:

1
model.fit(x, y, epochs=1, batch_size=None, verbose=1, validation_split=0.0, callbacks=None, ...)
  • x:训练数据。
  • y:目标输出数据。
  • epochs:训练的轮次,默认值为1。
  • batch_size:每个批次的样本数量,默认值为32。
  • verbose:训练日志显示模式,0为安静,1为进度条,2为每个epoch一行。
  • validation_split:用于验证的比例,将一部分训练数据分离出来用于验证。

下面是用fit方法训练模型的例子:

1
2
# 训练模型
model.fit(X, y, epochs=10, batch_size=32, verbose=1, validation_split=0.2)

在这个例子中,我们将训练模型的轮次设置为10,每个批次为32个样本,并将20%的数据用于验证。

训练过程中的监控

在调用fit方法时,Keras会自动在每个epoch结束时评估模型在训练集和验证集上的损失和准确率。这样,您就能在训练过程中监控模型的表现。

使用回调函数

在训练过程中,您可能希望在特定条件下保存模型或调整学习率。Keras提供了回调函数 callbacks,允许您在训练过程中扩展模型的功能。

1
2
3
4
5
6
7
8
9
10
11
from keras.callbacks import EarlyStopping, ModelCheckpoint

# 提前停止的回调函数
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

# 模型检查点回调函数
model_checkpoint = ModelCheckpoint('best_model.h5', save_best_only=True)

# 训练模型,并应用回调函数
model.fit(X, y, epochs=50, batch_size=32, verbose=1, validation_split=0.2,
callbacks=[early_stopping, model_checkpoint])

在这个例子中,EarlyStopping回调会监控验证损失,如果在3个连续epoch内没有改善,就会提前停止训练。同时,ModelCheckpoint会保存验证损失最小的模型,以便后续加载和使用。

训练完成后的步骤

训练完成后,您可以使用evaluate方法测试模型的性能,并使用predict方法进行预测。例如:

1
2
3
4
5
6
# 评估模型性能
loss, accuracy = model.evaluate(X, y)
print(f'Loss: {loss}, Accuracy: {accuracy}')

# 进行预测
predictions = model.predict(X)

您可以把这些步骤整合到您的工作流程中,以确保您的模型训练、评估和使用均高效有序。

小结

在本篇中,我们探讨了如何使用Keras中的fit方法来训练模型,监控训练过程,并利用回调函数增强训练过程的功能。这为后续的模型使用和调优打下了良好的基础。

在下一篇中,我们将重点讨论如何在Keras中使用fit方法的高级功能,以提高模型的训练效果和效率。希望能够继续与您一起深入探讨Keras框架的奥秘!

作者

IT教程网(郭震)

发布于

2024-08-15

更新于

2024-08-16

许可协议

分享转发

复习上节

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论