DL之NN:NN算法(本地数据集50000张训练集图片)进阶优化之三种参数改进,进一步提高手写数字图片识别的准确率

思路设计


首先,改变之一:


先在初始化权重的部分,采取一种更为好的随机初始化方法,我们依旧保持正态分布的均值不变,只对标准差进行改动,


初始化权重改变前,


def large_weight_initializer(self):  

       self.biases = [np.random.randn(y, 1) for y in self.sizes[1:]]

       self.weights = [np.random.randn(y, x)  for x, y in zip(self.sizes[:-1], self.sizes[1:])]

初始化权重改变后,


   def default_weight_initializer(self):

       self.biases = [np.random.randn(y, 1) for y in self.sizes[1:]]

       self.weights = [np.random.randn(y, x)/np.sqrt(x)  for x, y in zip(self.sizes[:-1], self.sizes[1:])]

改变之二:


为了减少Overfitting,降低数据局部噪音影响,将原先的目标函数由 quadratic cost 改为 cross-enrtopy cost


class CrossEntropyCost(object):

   def fn(a, y):

       return np.sum(np.nan_to_num(-y*np.log(a)-(1-y)*np.log(1-a)))

   def delta(z, a, y):

       return (a-y)

改变之三:


将S函数改为Softmax函数


class SoftmaxLayer(object):

   def __init__(self, n_in, n_out, p_dropout=0.0):

       self.n_in = n_in

       self.n_out = n_out

       self.p_dropout = p_dropout

       self.w = theano.shared(

           np.zeros((n_in, n_out), dtype=theano.config.floatX),

           name='w', borrow=True)

       self.b = theano.shared(

           np.zeros((n_out,), dtype=theano.config.floatX),

           name='b', borrow=True)

       self.params = [self.w, self.b]

   def set_inpt(self, inpt, inpt_dropout, mini_batch_size):

       self.inpt = inpt.reshape((mini_batch_size, self.n_in))

       self.output = softmax((1-self.p_dropout)*T.dot(self.inpt, self.w) + self.b)

       self.y_out = T.argmax(self.output, axis=1)

       self.inpt_dropout = dropout_layer(

           inpt_dropout.reshape((mini_batch_size, self.n_in)), self.p_dropout)

       self.output_dropout = softmax(T.dot(self.inpt_dropout, self.w) + self.b)

   def cost(self, net):

       "Return the log-likelihood cost."

       return -T.mean(T.log(self.output_dropout)[T.arange(net.y.shape[0]), net.y])

   def accuracy(self, y):

       "Return the accuracy for the mini-batch."

       return T.mean(T.eq(y, self.y_out))


上一篇:linux操作系统


下一篇:SQL 查询字段为值不为空