Jupyter AI

16 Keras框架从零教程系列:模型训练之回调函数

📅 发表日期: 2024年8月15日

分类: 📚Keras 入门

👁️阅读: --

在上一篇中,我们讨论了如何使用 fit 方法进行模型训练。这一篇将专注于 Keras 中的回调函数,这些函数在训练期间提供了一种灵活的方式,帮助我们监控模型训练状态、保存模型、调整学习率等。

什么是回调函数?

Keras 中,回调函数是执行特定操作的一组机制,这些操作会在模型训练的不同阶段被自动触发。它们可以在训练开始之前、每个训练周期结束之后、每个训练批次结束之后等时刻执行。借助回调函数,我们可以轻松扩展模型的训练过程。

常用的回调函数

1. EarlyStopping

EarlyStopping 用于监控模型的性能,当性能不再提升时,它会自动停止训练。它的主要参数包括:

  • monitor:监控的指标(如 val_lossval_accuracy)。
  • patience:在监测指标未改善时,允许的耐心周期数。
  • verbose:是否输出详细信息。
  • mode:监测指标的最佳模式(minmax)。

示例代码:

from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping

# 创建一个简单的模型
model = Sequential()
model.add(Dense(10, activation='relu', input_shape=(20,)))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 定义EarlyStopping回调
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1)

# 模拟数据
import numpy as np
X_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))
X_val = np.random.random((200, 20))
y_val = np.random.randint(2, size=(200, 1))

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[early_stopping])

2. ModelCheckpoint

ModelCheckpoint 用于在每个周期后保存模型。您可以选择在验证集上性能最好的模型被保存,或者按周期保存模型。

示例代码:

from keras.callbacks import ModelCheckpoint

# 定义ModelCheckpoint回调
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1)

# 训练模型并保存最佳模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[checkpoint])

3. ReduceLROnPlateau

ReduceLROnPlateau 在监测指标不再提升时降低学习率,对改善模型性能非常有帮助。

示例代码:

from keras.callbacks import ReduceLROnPlateau

# 定义ReduceLROnPlateau回调
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6)

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[reduce_lr])

4. TensorBoard

TensorBoard 提供了可视化工具,帮助我们更好地理解模型训练的过程。我们可以监控损失、准确率,甚至自定义的指标。

示例代码:

from keras.callbacks import TensorBoard

# 定义TensorBoard回调
tensorboard = TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True)

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=[tensorboard])

结合使用回调函数

在实际训练过程中,我们可以组合多个回调函数来优化模型训练。例如,结合 EarlyStopping, ModelCheckpointReduceLROnPlateau,可以在验证损失不再减少时自动停止训练并保存最佳模型,同时在必要时调整学习率。

示例代码:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=3, verbose=1),
    ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, mode='min', verbose=1),
    ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-6),
    TensorBoard(log_dir='./logs', histogram_freq=1, write_graph=True)
]

# 训练模型
model.fit(X_train, y_train, epochs=50, validation_data=(X_val, y_val), callbacks=callbacks)

小结

在本篇教程中,我们详细介绍了 Keras 中的回调函数,它们是增强模型训练能力的重要工具。通过合理地使用这些回调函数,我们不但可以监控训练过程,还能改善模型性能,为后续的模型评估与预测打下良好的基础。

在下一篇中,我们将讨论如何评估训练好的模型的性能以及如何进行预测,敬请期待!