'''
Author: huajia
Date: 2021-11-22 14:57:01
LastEditors: huajia
LastEditTime: 2021-12-01 12:53:56
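Description: Toy two-layer neural network (ReLU hidden layer, sigmoid output) trained by gradient descent; prints progress and plots the loss curve.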
'''
import numpy as np
import math
import matplotlib.pyplot as plt
def a(x, num, r, *, y=None, hidden_units=1, log_every=1000, plot=True):
    """Train a tiny two-layer network by batch gradient descent.

    The model is ``a2 = sigmoid(w2 @ relu(w1 @ x + b1) + b2)``, trained with
    binary cross-entropy averaged over the samples (columns of ``x``).

    Args:
        x: Inputs of shape ``(n_features, n_samples)``; each column is one
            sample. A 1-D array is treated as a single sample.
        num: Number of gradient-descent iterations.
        r: Learning rate.
        y: Labels of shape ``(1, n_samples)`` with values in {0, 1}.
            Defaults to all ones (the original hard-coded positive label).
        hidden_units: Number of ReLU units in the hidden layer (default 1).
        log_every: Print parameters/activations every ``log_every``
            iterations; ``0`` or ``None`` disables printing.
        plot: If true, plot the loss history with matplotlib and show it.

    Returns:
        Tuple ``(w1, b1, w2, b2, losses)``: the trained parameters and the
        per-iteration loss history. If the loss becomes non-finite, training
        stops early and ``losses`` holds only the finite values seen so far.
    """
    x = np.asarray(x, dtype=float)
    if x.ndim == 1:
        # Treat a 1-D input as one sample (a single column).
        x = x[:, np.newaxis]
    n_features, m = x.shape
    y = np.ones((1, m)) if y is None else np.asarray(y, dtype=float)

    # Small random weights break symmetry; biases start at zero.
    # w1 is drawn before w2 so a fixed seed reproduces the original init.
    w1 = np.random.randn(hidden_units, n_features) * 0.01
    b1 = np.zeros((hidden_units, 1))
    w2 = np.random.randn(1, hidden_units) * 0.01
    b2 = np.zeros((1, 1))

    delta = 1e-10  # keeps log() away from log(0)
    losses = []
    for i in range(num):
        # Forward pass.
        z1 = np.dot(w1, x) + b1
        a1 = np.maximum(z1, 0)  # ReLU
        z2 = np.dot(w2, a1) + b2
        a2 = 1 / (1 + np.exp(-z2))  # sigmoid

        # Cross-entropy averaged over samples, consistent with the /m gradients.
        loss = -np.sum(y * np.log(a2 + delta)
                       + (1 - y) * np.log(1 - a2 + delta)) / m
        if not np.isfinite(loss):
            # Parameters diverged: stop training instead of exiting the process.
            print('loss is not finite at iteration %d; stopping early' % i)
            break
        losses.append(loss)
        if log_every and i % log_every == 0:
            print('第%d轮:' % (i), 'w1:', w1, 'w2:', w2, 'b1:', b1, 'b2:', b2,
                  'z1:', z1, 'z2:', z2, 'a1:', a1, 'a2:', a2, 'less:', loss)

        # Backward pass (all gradients use the pre-update parameters).
        dz2 = a2 - y  # d(loss)/d(z2) for sigmoid + cross-entropy
        dw2 = np.dot(dz2, a1.T) / m
        db2 = np.sum(dz2, axis=1, keepdims=True) / m
        # ReLU derivative is 1 where z1 > 0, else 0.
        dz1 = np.dot(w2.T, dz2) * (z1 > 0)
        dw1 = np.dot(dz1, x.T) / m
        db1 = np.sum(dz1, axis=1, keepdims=True) / m

        w1 -= r * dw1
        b1 -= r * db1
        w2 -= r * dw2
        b2 -= r * db2

    if plot:
        # Use the recorded length: training may have stopped early.
        plt.plot(np.arange(len(losses)), losses, label="less")
        plt.legend()
        plt.show()
    return w1, b1, w2, b2, losses
if __name__ == "__main__":
    # Demo: a single 2-feature sample, 10000 iterations, learning rate 10.
    # Guarded so importing this module does not train (or exit the importer).
    x = np.array([[-10000], [1]])
    a(x, 10000, 10)