再捋一下正反向传播,之前有点乱(麻了)


'''
Author: huajia
Date: 2021-11-22 14:57:01
LastEditors: huajia
LastEditTime: 2021-12-01 12:53:56
Description: 略略略
'''

import numpy as np
import math
import matplotlib.pyplot as plt


def a(x, num, r, *, y=None, hidden=1, plot=True):
    """Train a 2-layer net (ReLU hidden layer, sigmoid output) by gradient descent.

    Runs ``num`` iterations of full-batch gradient descent on the binary
    cross-entropy loss, printing the parameters and activations every 1000
    iterations and recording the loss of every iteration.

    Args:
        x: Input matrix of shape (n_features, m_samples); each column is one
            sample.
        num: Number of gradient-descent iterations.
        r: Learning rate.
        y: Labels broadcastable against the (1, m_samples) output. Defaults
            to ``[[1]]`` (every sample labelled 1).
        hidden: Number of hidden ReLU units (default 1).
        plot: If True (default), plot the loss curve with matplotlib and
            show it.

    Returns:
        List of per-iteration losses (mean over samples), each computed
        before that iteration's parameter update.

    Raises:
        FloatingPointError: If the loss becomes NaN (e.g. NaN inputs or
            diverged parameters).
    """
    n_features, m = x.shape
    # Small random weights break symmetry; biases start at zero. The draw
    # order (w1 then w2) is kept so a fixed seed reproduces earlier runs.
    w1 = np.random.randn(hidden, n_features) * 0.01
    b1 = np.zeros((hidden, 1))
    w2 = np.random.randn(1, hidden) * 0.01
    b2 = np.zeros((1, 1))
    Y = np.array([[1]]) if y is None else np.asarray(y)
    delta = 1e-10  # keeps log() finite when a2 saturates at exactly 0 or 1
    losses = []
    for i in range(num):
        # ---- forward pass ----
        z1 = np.dot(w1, x) + b1
        a1 = np.maximum(z1, 0)  # ReLU
        z2 = np.dot(w2, a1) + b2
        a2 = 1 / (1 + np.exp(-z2))  # sigmoid
        loss = -np.sum(
            Y * np.log(a2 + delta) + (1 - Y) * np.log(1 - a2 + delta)
        ) / m
        if np.isnan(loss):
            raise FloatingPointError(f"loss became NaN at iteration {i}")
        losses.append(loss)
        if i % 1000 == 0:
            print('第%d轮:' % (i), 'w1:', w1, 'w2:', w2, 'b1:', b1, 'b2:', b2,
                  'z1:', z1, 'z2:', z2, 'a1:', a1, 'a2:', a2, 'less:', loss)
        # ---- backward pass ----
        # Gradients are averaged over the m samples (columns of x), matching
        # the mean loss above; dividing by the feature count was wrong.
        dz2 = a2 - Y  # d(loss)/d(z2) for sigmoid + binary cross-entropy
        dw2 = np.dot(dz2, a1.T) / m
        db2 = np.sum(dz2, axis=1, keepdims=True) / m
        # ReLU derivative: 1 where z1 > 0, else 0 (not the tanh 1 - a1**2).
        dz1 = np.dot(w2.T, dz2) * (z1 > 0)
        dw1 = np.dot(dz1, x.T) / m
        db1 = np.sum(dz1, axis=1, keepdims=True) / m
        # ---- parameter update ----
        w1 -= r * dw1
        b1 -= r * db1
        w2 -= r * dw2
        b2 -= r * db2
    if plot:
        plt.plot(np.arange(0, num), losses, label="less")
        plt.legend()
        plt.show()
    return losses

if __name__ == "__main__":
    # Demo: one sample (a single column) with two features, trained for
    # 10000 iterations at learning rate 10. Guarded so importing this module
    # does not start training or open a plot window.
    x = np.array([[-10000], [1]])
    a(x, 10000, 10)

上一篇:L1、L2正则化的理解


下一篇:正版三国杀