15 训练循环

15 训练循环

在使用 PyTorch 进行模型训练时,训练循环 是一个核心概念。它决定了如何通过多次迭代数据来优化模型参数。以下是一些关键步骤和细节,帮助你从零开始理解和实现训练循环。

1. 初始化模型和数据

在训练开始之前,我们需要定义模型、数据集和优化器。通常,这些操作在循环开始之前完成。

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 定义模型
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc = nn.Linear(10, 1) # 输入尺寸为 10,输出尺寸为 1

def forward(self, x):
return self.fc(x)

# 创建数据集
X = torch.randn(100, 10) # 100 个样本,每个样本有 10 个特征
y = torch.randn(100, 1) # 100 个目标值
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# 初始化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化

2. 训练循环

基本结构

训练循环通常包括以下几个部分:

  1. 迭代数据批次
  2. 前向传播
  3. 计算损失值
  4. 反向传播
  5. 更新模型参数
  6. 记录和输出信息

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
num_epochs = 5  # 总共训练的周期(epoch)数

for epoch in range(num_epochs):
for batch_idx, (data, target) in enumerate(dataloader):

# ------- 1. 前向传播 -------
output = model(data) # 将输入数据传入模型

# ------- 2. 计算损失 -------
loss = criterion(output, target) # 计算损失值

# ------- 3. 清零梯度 -------
optimizer.zero_grad() # 清除之前的梯度

# ------- 4. 反向传播 -------
loss.backward() # 计算当前梯度

# ------- 5. 更新参数 -------
optimizer.step() # 更新模型参数

# ------- 6. 输出信息 -------
if batch_idx % 10 == 0: # 每 10 个批次输出一次信息
print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{len(dataloader)}], Loss: {loss.item():.4f}')

3. 注意事项

  • 清零梯度:每个训练循环开始前,使用 optimizer.zero_grad() 清空旧的梯度。否则累积的梯度会导致模型参数更新不准确。

  • 损失函数:选择合适的损失函数是关键,它会影响模型的训练效果。例如,分类问题通常使用交叉熵损失,而回归问题则使用均方误差损失。

  • 学习率:选择合适的学习率(如lr)可以显著影响训练速度和最终模型的性能。使用优化器的学习率调度器可以自动调整学习率。

4. 评估模型

在训练循环结束后,通常需要对模型进行评估(验证集或测试集)。这可以帮助我们了解模型的泛化能力。

示例代码

1
2
3
4
5
6
7
8
9
10
# 在验证集上评估模型
model.eval() # 设定模型为评估模式
with torch.no_grad(): # 不计算梯度
total_loss = 0
for val_data, val_target in val_dataloader: # val_dataloader 是你的验证集 DataLoader
val_output = model(val_data)
val_loss = criterion(val_output, val_target)
total_loss += val_loss.item()

print(f'Validation Loss: {total_loss / len(val_dataloader):.4f}')

通过理解和实现 训练循环,你可以在 PyTorch 中高效地训练神经网络模型。这些步骤和示例代码将帮助你在实践中应用这一重要概念。

作者

AI教程网

发布于

2024-08-07

更新于

2024-08-10

许可协议