16 模型训练之回调函数

在上一篇中,我们讨论了如何使用 fit 方法进行模型训练。这一篇将专注于 Keras 中的回调函数,这些函数在训练期间提供了一种灵活的方式,帮助我们监控模型训练状态、保存模型、调整学习率等。

什么是回调函数?

Keras 中,回调函数是执行特定操作的一组机制,这些操作会在模型训练的不同阶段被自动触发。它们可以在训练开始之前、每个训练周期结束之后、每个训练批次结束之后等时刻执行。借助回调函数,我们可以轻松扩展模型的训练过程。

常用的回调函数

1. EarlyStopping

EarlyStopping 用于监控模型的性能,当性能不再提升时,它会自动停止训练。它的主要参数包括:

  • monitor:监控的指标(如 val_lossval_accuracy)。
  • patience:在监测指标未改善时,允许的耐心周期数。
  • verbose:是否输出详细信息。
  • mode:监测指标的最佳模式(minmax)。

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping

# 创建一个简单的模型
model = Sequential()
model.add(Dense(10, activation='relu', input_shape=(20,)))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 定义EarlyStopping回调
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1)

# 模拟数据
import numpy as np
X_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))
X_val = np.random.random((200, 20))
y_val = np.random.randint(2, size=(200, 1))

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[early_stopping])

2. ModelCheckpoint

ModelCheckpoint 用于在每个周期后保存模型。您可以选择在验证集上性能最好的模型被保存,或者按周期保存模型。

示例代码:

1
2
3
4
5
6
7
from keras.callbacks import ModelCheckpoint

# 定义ModelCheckpoint回调
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1)

# 训练模型并保存最佳模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[checkpoint])

3. ReduceLROnPlateau

ReduceLROnPlateau 在监测指标不再提升时降低学习率,对改善模型性能非常有帮助。

示例代码:

1
2
3
4
5
6
7
from keras.callbacks import ReduceLROnPlateau

# 定义ReduceLROnPlateau回调
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6)

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[reduce_lr])

4. TensorBoard

TensorBoard 提供了可视化工具,帮助我们更好地理解模型训练的过程。我们可以监控损失、准确率,甚至自定义的指标。

示例代码:

1
2
3
4
5
6
7
from keras.callbacks import TensorBoard

# 定义TensorBoard回调
tensorboard = TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True)

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[tensorboard])

结合使用回调函数

在实际训练过程中,我们可以组合多个回调函数来优化模型训练。例如,结合 EarlyStopping, ModelCheckpointReduceLROnPlateau,可以在验证损失不再减少时自动停止训练并保存最佳模型,同时在必要时调整学习率。

示例代码:

1
2
3
4
5
6
7
8
9
callbacks = [
EarlyStopping(monitor='val_loss', patience=3, verbose=1),
ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1),
ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6),
TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True)
]

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=callbacks)

小结

在本篇教程中,我们详细介绍了 Keras 中的回调函数,它们是增强模型训练能力的重要工具。通过合理地使用这些回调函数,我们不但可以监控训练过程,还能改善模型性能,为后续的模型评估与预测打下良好的基础。

在下一篇中,我们将讨论如何评估训练好的模型的性能以及如何进行预测,敬请期待!

16 模型训练之回调函数

https://zglg.work/keras-zero/16/

作者

IT教程网(郭震)

发布于

2024-08-15

更新于

2024-08-16

许可协议

分享转发

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论