Gambler's Problem,即“赌徒问题”,是动态规划(Dynamic Programming)中值迭代(Value Iteration)的一个经典应用问题。
在一个掷硬币游戏中,赌徒先下注:如果硬币为正面,赌徒赢回双倍赌注(即本金增加一份赌注);若是反面,则输掉赌注。赌徒给自己定了一个目标:本金赢到100块或是输光就结束游戏。问题是找到一个关于本金与赌注之间关系的策略,使得赌徒以尽可能大的概率赢到100块。状态s ∈ {0, 1, 2, ..., 99, 100},其中0和100为终止状态;在非终止状态s下,动作a ∈ {1, 2, 3, ..., min(s, 100 - s)}。奖励设置:只有当赌徒赢到100块时奖励为+1,其余情况奖励为0。
这个问题并不难:当硬币为正面的概率小于0.5时,a = min(s, 100 - s) 就是一个最优策略(policy),这里就不分析了,直接给出计算程序:
% Start from a clean command window and an empty workspace.
clc
clear
%% Initialize
% Every table below is indexed by capital + 1: index 1 stands for a capital
% of 0 (broke) and index 101 for a capital of 100 (goal reached).

% Scalars
hp = 0.4;       % probability that the coin shows heads, i.e. the bet is won
gamma = 0.5;    % discount factor; not applied by the active value update
delta = 100;    % presumably a convergence measure, initialised to a large value
i = 0;          % iteration counter, starts at zero

% Capital levels from which a bet can still be placed.
capital = 1:99;

% Action-value table: Q(capital + 1, stake).
Q = zeros(101, 101);
% Uniform table with every entry equal to 1/100.
ActionProb = ones(101, 101) / 100;

% Reward vector: 1 at the goal index, 0 everywhere else.
R = [zeros(1, 100), 1];
% The state values start out equal to the rewards, so V(101) = 1.
V = R;
%% Value Iteration
% In-place (undiscounted) value iteration for the gambler's problem.
%
% Reads : Q, V, hp from the initialisation section.
% Writes: Q(s + 1, a) - expected value of staking a at capital s,
%         V(s + 1)    - best value over the legal stakes at capital s.
% Index 1 of V is capital 0 and index 101 is capital 100; both are terminal
% and are never overwritten here.
%
% The sweeps continue until the largest change of any state value within
% one sweep drops to `tol` (or `maxSweeps` is hit), so the Q table used for
% the policy comes from a converged value function rather than from a
% fixed, small number of sweeps.
tol = 1e-9;          % convergence threshold on the per-sweep change of V
maxSweeps = 1000;    % safety cap so the loop always terminates
plottedSweeps = 9;   % draw the value function after each of the first sweeps

num = 1;             % number of the sweep currently being performed
delta = Inf;         % largest |V_new - V_old| seen in the latest sweep
while (delta > tol && num <= maxSweeps)
    delta = 0;
    for state = 1:99
        % Legal stakes: at least 1, at most what is owned, and never more
        % than what is still missing to reach 100.
        actions = 1:min(state, 100 - state);
        % Successor capitals, shifted by +1 to become indices into V.
        PossibleStateLose = state - actions + 1;
        PossibleStateWin = state + actions + 1;
        % Expected successor value; the only reward is carried by V(101) = 1.
        Q(state + 1, actions) = hp*V(PossibleStateWin) + (1 - hp)*V(PossibleStateLose);
        % Maximise over the legal stakes only.
        best = max(Q(state + 1, actions));
        delta = max(delta, abs(best - V(state + 1)));
        V(state + 1) = best;
    end
    if (num <= plottedSweeps)
        plot(V, 'LineWidth', 2)
        hold on
        grid on
    end
    num = num + 1;
end
% Final (converged) value function, drawn in black on top of the early sweeps.
plot(V, 'k', 'LineWidth', 2)
hold on
grid on
%%
% Greedy policy extraction.
% Map(s) is the stake with the largest action value at capital s.
% Row s + 1 of Q belongs to capital s (row 1 is capital 0), so the lookup
% must use Q(s + 1, :); reading Q(s, :) would return the policy of capital
% s - 1 instead. Only the legal stakes 1..min(s, 100 - s) are compared.
% `max` returns the first maximiser, so ties go to the smallest stake.
figure
Map = zeros(1, 100);   % Map(100) stays 0: at the goal there is nothing to bet
for state = 1:99
    actions = 1:min(state, 100 - state);
    [bestValue, index] = max(Q(state + 1, actions));
    Map(state) = index;
end
plot(1:99, Map(1:99), 'bo')
hold on
%% Test Part
% Monte-Carlo evaluation of the stake table Map.
%
% Reads : Map (stake per capital) and hp (probability of winning a bet).
% For each of nTrials trials, one game is played from every starting
% capital 1..100 until the capital reaches 100 or drops to 0.
%
% flag(s)  - number of trials in which the game started at s ended broke
% count(s) - total number of bets, over all trials, that did not end the
%            game broke (the winning bet is included, the losing one is not)
% FT(t)    - fraction of the 100 starting capitals that went broke in trial t
% ST(t)    - mean number of counted bets in trial t over the starting
%            capitals that did not go broke
nTrials = 1000;          % the same number is used to normalise the plots below
count = zeros(1, 100);
flag = zeros(1, 100);
FT = zeros(1, nTrials);
ST = zeros(1, nTrials);
for iter = 1:nTrials
    Mflag = zeros(1, 100);    % per-trial version of flag
    Mcount = zeros(1, 100);   % per-trial version of count
    for state = 1:100
        capital = state;
        while (capital < 100)
            stake = Map(capital);
            %stake = min(capital, 100 - capital);
            % Use the same win probability as the value iteration.
            if (rand < hp)
                capital = capital + stake;
            else
                capital = capital - stake;
            end
            if (capital <= 0)
                flag(state) = flag(state) + 1;
                Mflag(state) = Mflag(state) + 1;
                break
            end
            count(state) = count(state) + 1;
            Mcount(state) = Mcount(state) + 1;
        end
    end
    FT(iter) = sum(Mflag)/100;
    % A start at 100 never goes broke, so this selection is never empty.
    ST(iter) = mean(Mcount(Mflag ~= 1));
end
% Empirical probability of reaching 100 from each starting capital.
figure
plot(1 - flag/nTrials, 'bo')
% Average number of counted bets per trial for each starting capital.
figure
plot(count/nTrials)
% Left without semicolons on purpose so both summaries are printed.
mean(1 - FT)
mean(ST)