【强化学习】SARSA(lambda)与SARSA区别及python代码实现

一、概念介绍

单步更新:SARSA是一种单步更新法,每走一步,就更新一次自己的行为准则。虽然每一步都在进行更新,但在没有获得最终奖励的时候,当前所处的这一步实际上也得不到有效更新;直到获得最终奖励时,只有获得最终奖励的前一步被认为和获得奖励是有关联的。

回合更新:SARSA(lambda)用参数λ来代替我们想选择的更新步数。获得奖励后,之前经历过的每一步都会得到更新,也就是说,获得奖励之前的每一步都被认为和获得奖励是有关联的。λ是一个衰减值(trace decay),离奖励越近的步骤越重要。λ取0就成了单步更新,λ取1就变成了回合更新。

二、代码实现

  1. Brain

import numpy as np

import pandas as pd

class RL(object):
    """Base tabular agent: a lazily grown Q table plus epsilon-greedy selection.

    The Q table is a pandas DataFrame with one column per action and one row
    per state; a row of zeros is added the first time a state is seen.

    Args:
        action_space: list of action labels, used as the Q-table columns.
        learning_rate: step size (alpha) used by subclasses' updates.
        reward_decay: discount factor (gamma).
        e_greedy: probability of taking the greedy action; otherwise a
            uniformly random action is taken.
    """

    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space  # a list of action labels
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        """Append an all-zero row for ``state`` to the Q table if it is missing."""
        if state in self.q_table.index:
            return
        # DataFrame.append was removed in pandas 2.0, so concatenate a
        # one-row frame instead. Float zeros keep the table float64.
        new_row = pd.DataFrame(
            [[0.0] * len(self.actions)],
            index=[state],
            columns=self.q_table.columns,
        )
        self.q_table = pd.concat([self.q_table, new_row])

    def choose_action(self, observation):
        """Return an action for ``observation`` using epsilon-greedy selection.

        With probability ``self.epsilon`` the highest-valued action is taken
        (ties are broken uniformly at random); otherwise a uniformly random
        action from ``self.actions`` is returned. Unseen states are added to
        the Q table first.
        """
        self.check_state_exist(observation)
        if np.random.rand() < self.epsilon:
            # Exploit: several actions may share the max value, so pick
            # randomly among them to avoid always favouring the first column.
            state_action = self.q_table.loc[observation, :]
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # Explore: uniformly random action.
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        """Update the Q table from one transition; implemented by subclasses."""
        pass

# Backward view of SARSA(lambda): an eligibility trace spreads each TD error
# over recently visited state-action pairs, weighting pairs visited closer to
# the reward more heavily.

class SarsaLambdaTable(RL):
    """SARSA(lambda) agent with replacing eligibility traces.

    Args:
        actions: list of action labels (passed to ``RL`` as the action space).
        learning_rate: step size (alpha).
        reward_decay: discount factor (gamma).
        e_greedy: probability of choosing the greedy action.
        trace_decay: lambda, a value in [0, 1]. With 0 the trace vanishes after
            each step, so only the current pair is updated (one-step SARSA).
    """

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

        self.lambda_ = trace_decay  # expected to lie in [0, 1]
        # Eligibility trace table: same states x actions layout as the Q table.
        self.eligibility_trace = self.q_table.copy()

    def check_state_exist(self, state):
        """Add an all-zero row for ``state`` to both the Q table and the trace table."""
        if state in self.q_table.index:
            return
        # DataFrame.append was removed in pandas 2.0; concatenate instead.
        # Float zeros keep both tables float64.
        new_row = pd.DataFrame(
            [[0.0] * len(self.actions)],
            index=[state],
            columns=self.q_table.columns,
        )
        self.q_table = pd.concat([self.q_table, new_row])
        # Keep the trace table row-aligned with the Q table.
        self.eligibility_trace = pd.concat([self.eligibility_trace, new_row])

    def learn(self, s, a, r, s_, a_):
        """Apply one SARSA(lambda) update for the transition (s, a, r, s_, a_).

        Args:
            s, a: current state and the action taken in it.
            r: reward received.
            s_: next state; the string ``'terminal'`` marks the end of the
                episode, in which case the TD target is just ``r``.
            a_: action already chosen for ``s_`` (on-policy target).
        """
        # Register both states so .loc lookups below cannot raise KeyError.
        self.check_state_exist(s)
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        error = q_target - q_predict

        # Replacing trace: clear every action of state s, then set (s, a) to 1,
        # so repeated visits cannot grow the trace without bound.
        # (An accumulating trace would do `self.eligibility_trace.loc[s, a] += 1`.)
        self.eligibility_trace.loc[s, :] = 0.0
        self.eligibility_trace.loc[s, a] = 1.0

        # Every pair moves in proportion to its eligibility.
        self.q_table += self.lr * error * self.eligibility_trace

        # Decay all traces before the next step.
        self.eligibility_trace *= self.gamma * self.lambda_

2. Test

from maze_env import Maze

from RL_brain import SarsaLambdaTable

def update():
    """Train the module-level agent ``RL`` in the module-level ``env`` for 100 episodes.

    Each episode resets the environment, clears the eligibility traces, and
    then runs on-policy SARSA steps until ``env.step`` reports ``done``.
    Afterwards it prints a message and destroys the environment.
    """
    for _ in range(100):
        # Start of an episode: initial observation and first action.
        state = env.reset()
        act = RL.choose_action(str(state))
        # Traces must not leak across episodes.
        RL.eligibility_trace *= 0

        done = False
        while not done:
            env.render()
            # Act, observe the outcome, and choose the follow-up action
            # before learning (the on-policy SARSA tuple).
            next_state, reward, done = env.step(act)
            next_act = RL.choose_action(str(next_state))
            RL.learn(str(state), act, reward, str(next_state), next_act)
            state, act = next_state, next_act

    print('game over')
    env.destroy()

if __name__ == "__main__":
    # `env` and `RL` must keep these exact names: update() reads them as globals.
    env = Maze()
    # One action label per environment action: 0 .. n_actions - 1.
    RL = SarsaLambdaTable(actions=[*range(env.n_actions)])
    # Schedule training, then enter the environment's event loop.
    # NOTE(review): presumably a tkinter-style `after(ms, func)`; confirm in maze_env.
    env.after(100, update)
    env.mainloop()

上一篇:Java应用日志如何与Jaeger的trace关联


下一篇:【Spring Cloud Alibaba】Sleuth + Zipkin 链路追踪