在上一篇中,我们讨论了如何使用 fit
方法进行模型训练。这一篇将专注于 Keras
中的回调函数,这些函数在训练期间提供了一种灵活的方式,帮助我们监控模型训练状态、保存模型、调整学习率等。
什么是回调函数?
在 Keras
中,回调函数是执行特定操作的一组机制,这些操作会在模型训练的不同阶段被自动触发。它们可以在训练开始之前、每个训练周期结束之后、每个训练批次结束之后等时刻执行。借助回调函数,我们可以轻松扩展模型的训练过程。
常用的回调函数
1. EarlyStopping
EarlyStopping
用于监控模型的性能,当性能不再提升时,它会自动停止训练。它的主要参数包括:
monitor
:监控的指标(如 val_loss
或 val_accuracy
)。
patience
:在监测指标未改善时,允许的耐心周期数。
verbose
:是否输出详细信息。
mode
:监测指标的最佳模式(min
或 max
)。
示例代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| from keras.models import Sequential from keras.layers import Dense from keras.callbacks import EarlyStopping
model = Sequential() model.add(Dense(10, activation='relu', input_shape=(20,))) model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1)
import numpy as np X_train = np.random.random((1000, 20)) y_train = np.random.randint(2, size=(1000, 1)) X_val = np.random.random((200, 20)) y_val = np.random.randint(2, size=(200, 1))
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[early_stopping])
|
2. ModelCheckpoint
ModelCheckpoint
用于在每个周期后保存模型。您可以选择在验证集上性能最好的模型被保存,或者按周期保存模型。
示例代码:
1 2 3 4 5 6 7
| from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1)
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[checkpoint])
|
3. ReduceLROnPlateau
ReduceLROnPlateau
在监测指标不再提升时降低学习率,对改善模型性能非常有帮助。
示例代码:
1 2 3 4 5 6 7
| from keras.callbacks import ReduceLROnPlateau
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6)
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[reduce_lr])
|
4. TensorBoard
TensorBoard
提供了可视化工具,帮助我们更好地理解模型训练的过程。我们可以监控损失、准确率,甚至自定义的指标。
示例代码:
1 2 3 4 5 6 7
| from keras.callbacks import TensorBoard
tensorboard = TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True)
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[tensorboard])
|
结合使用回调函数
在实际训练过程中,我们可以组合多个回调函数来优化模型训练。例如,结合 EarlyStopping
, ModelCheckpoint
和 ReduceLROnPlateau
,可以在验证损失不再减少时自动停止训练并保存最佳模型,同时在必要时调整学习率。
示例代码:
1 2 3 4 5 6 7 8 9
| callbacks = [ EarlyStopping(monitor='val_loss', patience=3, verbose=1), ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1), ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6), TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True) ]
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=callbacks)
|
小结
在本篇教程中,我们详细介绍了 Keras
中的回调函数,它们是增强模型训练能力的重要工具。通过合理地使用这些回调函数,我们不但可以监控训练过程,还能改善模型性能,为后续的模型评估与预测打下良好的基础。
在下一篇中,我们将讨论如何评估训练好的模型的性能以及如何进行预测,敬请期待!