一、概念介绍
单步更新:SARSA是一种单步更新法,每走一步,更新一下自己的行为准则。虽然每一步都在进行更新,但没有获得最终奖励的时候现在所处的的这一步也没获得更新,直到获得最终奖励,获得最终奖励的前一步认为和获得奖励是有关联的。
回合更新:SARSA(lambda)用来代替我们想选择的步数。获得最终奖励后才会进行更新,但是获得奖励的每一步都被认为和获得奖励是有关联的。λ是一个局部搜索的权重值-衰变值,离λ将离越近越重要。λ取0就成了单步更新,λ取1就变形成了回合更新。
二 代码实现
- Brain
import numpy as np
import pandas as pd
class RL(object):
def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):#定义变量
self.actions = action_space # a list
self.lr = learning_rate
self.gamma = reward_decay
self.epsilon = e_greedy
self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
def check_state_exist(self, state):#检查状态,有没有缺失
if state not in self.q_table.index:
# append new state to q table
self.q_table = self.q_table.append(
pd.Series(
[0]*len(self.actions),
index=self.q_table.columns,
name=state,
)
)
def choose_action(self, observation):#选择动作
self.check_state_exist(observation)
# action selection
if np.random.rand() < self.epsilon:
# 选择最优动作
state_action = self.q_table.loc[observation, :]
# some actions may have the same value, randomly choose on in these actions
action = np.random.choice(state_action[state_action == np.max(state_action)].index)
else:
# 随机选择动作
action = np.random.choice(self.actions)
return action
def learn(self, *args):
pass
# 向后看的方式,离奖励越近越重要
class SarsaLambdaTable(RL):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
self.lambda_ = trace_decay#0-1的值
self.eligibility_trace = self.q_table.copy()#state-action的表
def check_state_exist(self, state):
if state not in self.q_table.index:
# 增加新的state
to_be_append = pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
self.q_table = self.q_table.append(to_be_append)
# also update eligibility trace
self.eligibility_trace = self.eligibility_trace.append(to_be_append)
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal
else:
q_target = r # next state is terminal
error = q_target - q_predict
# increase trace amount for visited state-action pair
# Method 1:无封顶
# self.eligibility_trace.loc[s, a] += 1
# Method 2:#有封顶
self.eligibility_trace.loc[s, :] *= 0
self.eligibility_trace.loc[s, a] = 1
# Q update
self.q_table += self.lr * error * self.eligibility_trace
# decay eligibility trace after update
self.eligibility_trace *= self.gamma*self.lambda_
2.Test
from maze_env import Maze
from RL_brain import SarsaLambdaTable
def update():
for episode in range(100):
# 初始观测值
observation = env.reset()
# 基于观测值选择动作
action = RL.choose_action(str(observation))
# 开始均为0
RL.eligibility_trace *= 0
while True:
# 更新环境
env.render()
# 采取动作得到下一步观测值和奖励
observation_, reward, done = env.step(action)
# 基于观测进行动作的选择
action_ = RL.choose_action(str(observation_))
# RL learn from this transition (s, a, r, s, a) ==> Sarsa
RL.learn(str(observation), action, reward, str(observation_), action_)
# 更新观测和动作
observation = observation_
action = action_
# 当回合结束时进行打断
if done:
break
# 结束游戏
print('game over')
env.destroy()
if __name__ == "__main__":
env = Maze()
RL = SarsaLambdaTable(actions=list(range(env.n_actions)))
env.after(100, update)
env.mainloop()