16 Keras框架从零教程系列:模型训练之回调函数
在上一篇中,我们讨论了如何使用 fit
方法进行模型训练。这一篇将专注于 Keras
中的回调函数,这些函数在训练期间提供了一种灵活的方式,帮助我们监控模型训练状态、保存模型、调整学习率等。
什么是回调函数?
在 Keras
中,回调函数是执行特定操作的一组机制,这些操作会在模型训练的不同阶段被自动触发。它们可以在训练开始之前、每个训练周期结束之后、每个训练批次结束之后等时刻执行。借助回调函数,我们可以轻松扩展模型的训练过程。
常用的回调函数
1. EarlyStopping
EarlyStopping
用于监控模型的性能,当性能不再提升时,它会自动停止训练。它的主要参数包括:
monitor
:监控的指标(如val_loss
或val_accuracy
)。patience
:在监测指标未改善时,允许的耐心周期数。verbose
:是否输出详细信息。mode
:监测指标的最佳模式(min
或max
)。
示例代码:
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping
# 创建一个简单的模型
model = Sequential()
model.add(Dense(10, activation='relu', input_shape=(20,)))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 定义EarlyStopping回调
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1)
# 模拟数据
import numpy as np
X_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))
X_val = np.random.random((200, 20))
y_val = np.random.randint(2, size=(200, 1))
# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[early_stopping])
2. ModelCheckpoint
ModelCheckpoint
用于在每个周期后保存模型。您可以选择在验证集上性能最好的模型被保存,或者按周期保存模型。
示例代码:
from keras.callbacks import ModelCheckpoint
# 定义ModelCheckpoint回调
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1)
# 训练模型并保存最佳模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[checkpoint])
3. ReduceLROnPlateau
ReduceLROnPlateau
在监测指标不再提升时降低学习率,对改善模型性能非常有帮助。
示例代码:
from keras.callbacks import ReduceLROnPlateau
# 定义ReduceLROnPlateau回调
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6)
# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[reduce_lr])
4. TensorBoard
TensorBoard
提供了可视化工具,帮助我们更好地理解模型训练的过程。我们可以监控损失、准确率,甚至自定义的指标。
示例代码:
from keras.callbacks import TensorBoard
# 定义TensorBoard回调
tensorboard = TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True)
# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[tensorboard])
结合使用回调函数
在实际训练过程中,我们可以组合多个回调函数来优化模型训练。例如,结合 EarlyStopping
, ModelCheckpoint
和 ReduceLROnPlateau
,可以在验证损失不再减少时自动停止训练并保存最佳模型,同时在必要时调整学习率。
示例代码:
callbacks = [
EarlyStopping(monitor='val_loss', patience=3, verbose=1),
ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1),
ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6),
TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True)
]
# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=callbacks)
小结
在本篇教程中,我们详细介绍了 Keras
中的回调函数,它们是增强模型训练能力的重要工具。通过合理地使用这些回调函数,我们不但可以监控训练过程,还能改善模型性能,为后续的模型评估与预测打下良好的基础。
在下一篇中,我们将讨论如何评估训练好的模型的性能以及如何进行预测,敬请期待!