Pytorch学习

Pytorch学习 -- 深度学习

一级目录

二级目录

三级目录

Pytorch必须在init初始化网络结构 forward中做feed forward网络的前馈
创建网络结构代码
待更新知识点学习:

  1. 张量tensor的各种操作.argmax() add() 等等 link
  2. nn.module 父类
  3. nn.Sequential()
  4. nn.module中的各层 nn.Linear()
  5. 激活函数 nn.ReLU() nn.LeackyReLU() nn.ELU() 等等
  6. 损失函数
  7. 优化器
import torch
import torch.nn as nn
from collections import OrderedDict

class Net(nn.Module):  # nn.Module is standard PyTorch Network
    def __init__(self,state_dim, mid_dim,  action_dim):
        '''
        相当于create model
        :param mid_dim: 中间层神经元数
        :param state_dim: 状态层神经元数 输入
        :param action_dim: 动作层神经元数 输出
        '''
        super().__init__()  # 第一句话,调用父类的构造函数
        self.net = nn.Sequential(
            nn.Linear(state_dim, mid_dim),
            nn.ELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ELU(),
            nn.Linear(mid_dim, action_dim)
        )

    def forward(self, state):
        return self.net(state)  # 计算Q-value  直接返回action-dim的张量

主函数【注:该代码无法运行,目前用于学习流程】

if __name__ == '__main__':
    '''
    听说 pytorch的训练需要自己写?
    好吧 学习ing
    '''
    # 初始化 模型的类
    net = QNet(13, 7, 8)  # 输入13维向量 隐藏层7维 输出8维向量
    # 选择 损失函数和优化器
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(net.parameters(),lr=1e-4)

    # 可以开始训练了
    state = torch.Tensor([20,15,14,12,20.00,50,7,4.98,0.4,0.8,0.1,0,5])
    # y是目标值 
    for t in range(500):
        y_pred = net(state)  # 必须传入tensor 
        loss = criterion(y_pred,y)   # 计算损失函数 但是强化学习没有y呀???疑惑??
        optimizer.zero_grad() # 梯度置零
        loss.backward()
        optimizer.step()

总结:使用pytorch训练强化学习算法DQN模型
主训练过程

max_episodes
max_steps

for episode in max_episodes:
	state = get_state() # 获得初始状态
	for t in max_steps:
		action = choose_action(state) #根据当前状态选择动作
		_,reward,done,_ = env.step(action)  # 执行动作 获得奖励
		next_state = get_state() # 观察获得新状态
		memory.push(state,action,reward,next_state)  # 将transition存入经验缓冲池
		optimize_model()  # 优化模型
	if episode%target_update == 0: # 如果到了target_net更新的轮次,更新target_net
		target_net.update()  

其中optimize_model() 是本次记录的重点 pytorch构建的模型是如何优化更新的呢
经验缓冲池功能:存储经验,随机采样,
agent功能:choose_action() target_net.update()
环境功能:

def optimize_model():
	'''
	step1:首先从经验缓冲池中进行随机采样,将其拼接??
	'''
	if len(memory)<batch_size:
		return 
	transitions = memeory.sample()
	# 拼接操作  torch.cat()
	# 计算Q(st,a)  得到采取行动的列
 	state_action_values = policy_net(state_batch).gather(1,action_batch)
	# 计算下一个状态
	next_state_values = target_net(state).max(1)[0].detach()
	# 计算期望q值
	excepted_state_action_value = (next_state_values * gamma) + reward_batch
	
	# 计算loss
	loss = torch.nn.functional.smooth_l1_loss(state_action_values ,excepted_state_action_value )
	
	# 优化模型
	optimizer.zero_grad()
	loss.backward()
	
	
上一篇:09Oracle Database 数据表数据插入,更新,删除


下一篇:Redux 入门学习(1)