DL之DNN优化技术:DNN优化器的参数优化—更新参数的四种最优化方法(SGD/Momentum/AdaGrad/Adam)的案例理解、图表可视化比较

优化器案例理解

输出结果

DL之DNN优化技术:DNN优化器的参数优化—更新参数的四种最优化方法(SGD/Momentum/AdaGrad/Adam)的案例理解、图表可视化比较

DL之DNN优化技术:DNN优化器的参数优化—更新参数的四种最优化方法(SGD/Momentum/AdaGrad/Adam)的案例理解、图表可视化比较

DL之DNN优化技术:DNN优化器的参数优化—更新参数的四种最优化方法(SGD/Momentum/AdaGrad/Adam)的案例理解、图表可视化比较

设计思路

DL之DNN优化技术:DNN优化器的参数优化—更新参数的四种最优化方法(SGD/Momentum/AdaGrad/Adam)的案例理解、图表可视化比较


核心代码

#T1、SGD算法

class SGD:

'……'

   def update(self, params, grads):

       for key in params.keys():

           params[key] -= self.lr * grads[key]

#T2、Momentum算法

import numpy as np

class Momentum:

'……'

   def update(self, params, grads):

       if self.v is None:

           self.v = {}

           for key, val in params.items():                                

               self.v[key] = np.zeros_like(val)

       for key in params.keys():

           self.v[key] = self.momentum*self.v[key] - self.lr*grads[key]

           params[key] += self.v[key]

#T3、AdaGrad算法

'……'

       

   def update(self, params, grads):

       if self.h is None:

           self.h = {}

           for key, val in params.items():

               self.h[key] = np.zeros_like(val)

       for key in params.keys():

           self.h[key] += grads[key] * grads[key]

           params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)  

#T4、Adam算法

'……'

       

   def update(self, params, grads):

       if self.m is None:

           self.m, self.v = {}, {}

           for key, val in params.items():

               self.m[key] = np.zeros_like(val)

               self.v[key] = np.zeros_like(val)

       self.iter += 1

       lr_t  = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter)        

       

       for key in params.keys():

           self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key])

           self.v[key] += (1 - self.beta2) * (grads[key]**2 - self.v[key])

           

           params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)


上一篇:SQL Server-存储过程(Procedure),带入参数和出参数


下一篇:Android开发重要参考资料