Jupyter AI

14 强化学习之SARSA算法

📅 发表日期: 2024年8月15日

分类: 🤖强化学习入门

👁️阅读: --

在了解完时序差分学习的基本概念后,我们接下来将深入探讨一种具体的时序差分学习方法——SARSA(State-Action-Reward-State-Action)算法。SARSA 是一种在线的强化学习算法,它通过与环境的互动来学习状态-动作值函数,从而实现策略的改进。接下来,我们将通过理论、示例和代码,详细介绍 SARSA 算法的原理和实现。

1. SARSA算法的基本原理

SARSA的名称来源于它更新Q值的方式:它同时考虑当前状态、当前动作、奖励、下一个状态和下一个动作。具体而言,SARSA算法的核心更新公式为:

Q(st,at)Q(st,at)+α(rt+γQ(st+1,at+1)Q(st,at))Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left( r_t + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right)

其中:

  • sts_t 是在时间 tt 的状态。
  • ata_t 是在时间 tt 采取的动作。
  • rtr_t 是在状态 sts_t 采取动作 ata_t 后获得的奖励。
  • st+1s_{t+1} 是在时间 t+1t+1 的状态。
  • at+1a_{t+1} 是在时间 t+1t+1 依据当前策略选择的动作。
  • α\alpha 是学习率,用于控制新信息对旧信息的更新。
  • γ\gamma 是折扣因子,用于平衡当前奖励与未来奖励的重要性。

1.1 SARSA算法的特点

  • 在线学习:SARSA 是一种在线学习算法,意味着 agente 将持续更新其策略,而不是在事后进行训练。
  • 探索与利用:通过 ϵ\epsilon-贪婪策略,SARSA 进行探索和利用的权衡,确保在学习过程中不会陷入局部最优。

2. 实际案例

为了更好地理解SARSA算法,我们可以考虑一个简化的迷宫问题,其中代理需要从起点移动到终点。在每个步骤中,它可以选择向上、下、左或右移动,并根据移动的结果得到奖励。我们的目标是通过SARSA算法来找到最优策略。

2.1 迷宫环境的设置

假设我们的迷宫如下所示,其中 S 是起点,G 是终点,-1 表示墙,0 表示可通行的路径:

S  0  0  0
0 -1  0  G
0  0 -1  0

奖励设置

  • 到达 G 的奖励是 +10+10
  • 每移动一步的奖励是1-1
  • 碰到墙的奖励是1-1

2.2 SARSA算法的实现

以下是使用 Python 和 NumPy 实现 SARSA 算法的一个简单示例:

import numpy as np
import random

# 环境设置
maze = np.array([[0, 0, 0, 10],
                 [0, -1, 0, -1],
                 [0, 0, -1, 0]])  # 0:可通行, -1:墙, 10:目标
actions = [0, 1, 2, 3]  # 上、下、左、右
q_table = np.zeros((3, 4, len(actions)))  # (状态数, 动作数)

# 超参数
alpha = 0.1  # 学习率
gamma = 0.9  # 折扣因子
epsilon = 0.1  # 探索率

def choose_action(state):
    if random.uniform(0, 1) < epsilon:  # 探索
        return random.choice(actions)
    else:  # 利用
        return np.argmax(q_table[state[0], state[1]])

def update_q_table(state, action, reward, next_state, next_action):
    q_table[state[0], state[1], action] += alpha * (
        reward + gamma * q_table[next_state[0], next_state[1], next_action] - q_table[state[0], state[1], action]
    )

# 训练
for episode in range(1000):
    state = (0, 0)  # 初始化状态为起点
    action = choose_action(state)
    
    while True:
        next_state = (state[0] + (action == 0) - (action == 1), 
                      state[1] + (action == 3) - (action == 2))  # 更新状态
                      
        # 确保新状态在边界内
        if next_state[0] < 0 or next_state[0] >= maze.shape[0] or next_state[1] < 0 or next_state[1] >= maze.shape[1]:
            next_state = state
        
        reward = maze[next_state] if maze[next_state] != -1 else -1  # 碰撞墙壁的情况
        next_action = choose_action(next_state)  # 根据新状态选择下一个动作
        
        # 更新Q表
        update_q_table(state, action, reward, next_state, next_action)
        
        state = next_state
        action = next_action
        
        if maze[state] == 10:  # 如果到达目标
            break

# 打印Q表
print("学习后的Q表:")
print(q_table)

3. 总结

SARSA算法作为一种基于时序差分学习的强化学习方法,能够有效地通过与环境的交互逐步学习到最优策略。在迷宫问题中,SARSA通过不断更新状态-动作值函数,不仅平衡了探索与利用,还在复杂环境中逐步逼近最优策略。在下一篇中,我们将探讨另一种重要的时序差分学习算法——Q学习,帮助大家深入理解这一领域。