14 强化学习之SARSA算法

在了解完时序差分学习的基本概念后,我们接下来将深入探讨一种具体的时序差分学习方法——SARSA(State-Action-Reward-State-Action)算法。SARSA 是一种在线的强化学习算法,它通过与环境的互动来学习状态-动作值函数,从而实现策略的改进。接下来,我们将通过理论、示例和代码,详细介绍 SARSA 算法的原理和实现。

1. SARSA算法的基本原理

SARSA的名称来源于它更新Q值的方式:它同时考虑当前状态、当前动作、奖励、下一个状态和下一个动作。具体而言,SARSA算法的核心更新公式为:

$$
Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left( r_t + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right)
$$

其中:

  • $s_t$ 是在时间 $t$ 的状态。
  • $a_t$ 是在时间 $t$ 采取的动作。
  • $r_t$ 是在状态 $s_t$ 采取动作 $a_t$ 后获得的奖励。
  • $s_{t+1}$ 是在时间 $t+1$ 的状态。
  • $a_{t+1}$ 是在时间 $t+1$ 依据当前策略选择的动作。
  • $\alpha$ 是学习率,用于控制新信息对旧信息的更新。
  • $\gamma$ 是折扣因子,用于平衡当前奖励与未来奖励的重要性。

1.1 SARSA算法的特点

  • 在线学习:SARSA 是一种在线学习算法,意味着 agente 将持续更新其策略,而不是在事后进行训练。
  • 探索与利用:通过 $\epsilon$-贪婪策略,SARSA 进行探索和利用的权衡,确保在学习过程中不会陷入局部最优。

2. 实际案例

为了更好地理解SARSA算法,我们可以考虑一个简化的迷宫问题,其中代理需要从起点移动到终点。在每个步骤中,它可以选择向上、下、左或右移动,并根据移动的结果得到奖励。我们的目标是通过SARSA算法来找到最优策略。

2.1 迷宫环境的设置

假设我们的迷宫如下所示,其中 S 是起点,G 是终点,-1 表示墙,0 表示可通行的路径:

1
2
3
S  0  0  0
0 -1 0 G
0 0 -1 0

奖励设置

  • 到达 G 的奖励是 $+10$。
  • 每移动一步的奖励是$-1$。
  • 碰到墙的奖励是$-1$。

2.2 SARSA算法的实现

以下是使用 Python 和 NumPy 实现 SARSA 算法的一个简单示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import random

# 环境设置
maze = np.array([[0, 0, 0, 10],
[0, -1, 0, -1],
[0, 0, -1, 0]]) # 0:可通行, -1:墙, 10:目标
actions = [0, 1, 2, 3] # 上、下、左、右
q_table = np.zeros((3, 4, len(actions))) # (状态数, 动作数)

# 超参数
alpha = 0.1 # 学习率
gamma = 0.9 # 折扣因子
epsilon = 0.1 # 探索率

def choose_action(state):
if random.uniform(0, 1) < epsilon: # 探索
return random.choice(actions)
else: # 利用
return np.argmax(q_table[state[0], state[1]])

def update_q_table(state, action, reward, next_state, next_action):
q_table[state[0], state[1], action] += alpha * (
reward + gamma * q_table[next_state[0], next_state[1], next_action] - q_table[state[0], state[1], action]
)

# 训练
for episode in range(1000):
state = (0, 0) # 初始化状态为起点
action = choose_action(state)

while True:
next_state = (state[0] + (action == 0) - (action == 1),
state[1] + (action == 3) - (action == 2)) # 更新状态

# 确保新状态在边界内
if next_state[0] < 0 or next_state[0] >= maze.shape[0] or next_state[1] < 0 or next_state[1] >= maze.shape[1]:
next_state = state

reward = maze[next_state] if maze[next_state] != -1 else -1 # 碰撞墙壁的情况
next_action = choose_action(next_state) # 根据新状态选择下一个动作

# 更新Q表
update_q_table(state, action, reward, next_state, next_action)

state = next_state
action = next_action

if maze[state] == 10: # 如果到达目标
break

# 打印Q表
print("学习后的Q表:")
print(q_table)

3. 总结

SARSA算法作为一种基于时序差分学习的强化学习方法,能够有效地通过与环境的交互逐步学习到最优策略。在迷宫问题中,SARSA通过不断更新状态-动作值函数,不仅平衡了探索与利用,还在复杂环境中逐步逼近最优策略。在下一篇中,我们将探讨另一种重要的时序差分学习算法——Q学习,帮助大家深入理解这一领域。

作者

IT教程网(郭震)

发布于

2024-08-15

更新于

2024-08-16

许可协议

分享转发

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论