# Plot two results CSVs side by side: for each file, draw the per-step mean
# over all "avg_reward" columns together with a shaded min/max band.
# NOTE(review): assumes plt (matplotlib.pyplot), sns (seaborn) and pd (pandas)
# are imported earlier in the file -- confirm.

# One results file per subplot (left, right) and the line colour for each.
CSV_FILES = ['PongNoFrameskip-v0_0_bcq.csv', 'PongNoFrameskip-v0_0_bcq_0.csv']
COLORS = ['orange', 'purple']

FONT_SIZE = 10   # subplot titles and tick labels
LABEL_SIZE = 20  # axis labels


def load_rewards(csv_path):
    """Read one results CSV and return ``(steps, rewards)``.

    ``steps`` is the ``"step"`` column as a 1-D array. ``rewards`` is a 2-D
    array with one row per column whose name contains ``"avg_reward"``, so
    ``rewards.mean(axis=0)`` is the per-step mean across those columns.

    Raises:
        ValueError: if no column name contains ``"avg_reward"`` (otherwise the
            later min/max reduction would fail with an opaque error).
    """
    frame = pd.read_csv(csv_path)
    reward_cols = [col for col in frame.columns if "avg_reward" in col]
    if not reward_cols:
        raise ValueError(
            "{}: no column name contains 'avg_reward'".format(csv_path))
    # Built fresh on every call so one file's runs never leak into another
    # file's statistics.
    rewards = frame[reward_cols].to_numpy().T
    steps = frame["step"].to_numpy()
    return steps, rewards


def plot_rewards(ax, csv_path, color, title, line_label='RCSCA_Net',
                 band_label='bcq', show_ylabel=True, show_legend=True,
                 show_grid=False):
    """Draw the mean reward curve and its min/max band for one CSV on ``ax``."""
    steps, rewards = load_rewards(csv_path)
    # x=/y= as keywords: seaborn >= 0.12 rejects positional data vectors.
    sns.lineplot(x=steps, y=rewards.mean(axis=0), ax=ax, legend=True,
                 label=line_label, lw=3, color=color)
    # Shade the spread in the line's own colour so each band is visually
    # tied to its curve.
    ax.fill_between(steps, rewards.max(axis=0), rewards.min(axis=0),
                    alpha=0.2, label=band_label, color=color)

    ax.set_title(title, fontsize=FONT_SIZE)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Time steps', fontdict={'size': LABEL_SIZE})
    ax.tick_params(labelsize=FONT_SIZE)
    if show_ylabel:
        ax.set_ylabel('Average Return', fontdict={'size': LABEL_SIZE})
    if show_legend:
        # Rebuild the legend so it also lists the shaded band.
        ax.legend()
    if show_grid:
        ax.grid(True)


def main():
    """Plot both runs in a 1x2 grid and show the figure."""
    plt.style.use('ggplot')

    left = plt.subplot(121)
    plot_rewards(left, CSV_FILES[0], COLORS[0], "BCQ",
                 show_ylabel=True, show_legend=True, show_grid=False)

    # Panel 122 (not 121): the second run gets its own axes instead of
    # being drawn on top of the first.
    right = plt.subplot(122)
    plot_rewards(right, CSV_FILES[1], COLORS[1], "bcq",
                 show_ylabel=False, show_legend=False, show_grid=True)

    plt.show()


if __name__ == "__main__":
    main()