# 代码 (Code)
# -*- coding: utf-8 -*-
"""
Created on Wed Feb 23 20:37:01 2022
@author: koneko
"""
import numpy as np
import matplotlib.pyplot as plt
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    Evaluated as ``exp(-log(1 + exp(-x)))`` with ``np.logaddexp`` so that large
    negative inputs do not overflow ``np.exp``. The naive form emits
    "overflow encountered in exp" once ``-x`` exceeds roughly 709.

    Parameters
    ----------
    x : scalar or ndarray
        Input value(s).

    Returns
    -------
    numpy scalar or ndarray
        Sigmoid of ``x``, in [0, 1], with the same shape as ``x``.
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow.
    return np.exp(-np.logaddexp(0.0, -x))


def mean_squared_error(y, t):
    """Return half the sum of squared differences between ``y`` and ``t``.

    Despite the name this is not a mean: the result is
    ``0.5 * sum((y - t) ** 2)`` taken over every element. The 0.5 factor makes
    the gradient with respect to ``y`` simply ``y - t``.
    """
    residual = y - t
    return 0.5 * np.sum(residual ** 2)


class Sigmoid:
    """Sigmoid activation layer that caches its forward output for backprop."""

    def __init__(self):
        # Result of the most recent forward() call; backward() reads it.
        self.out = None

    def forward(self, x):
        """Apply ``sigmoid`` to ``x``, store the result, and return it."""
        self.out = sigmoid(x)
        return self.out

    def backward(self, dout):
        """Propagate the upstream gradient ``dout`` through the sigmoid.

        Uses sigma'(x) = sigma(x) * (1 - sigma(x)), with sigma(x) taken from
        the cached forward output. forward() must have been called first.
        """
        activated = self.out
        return dout * (1.0 - activated) * activated


x = np.linspace(-np.pi, np.pi, 1000)
y = np.sin(x)
plt.plot(x, y)
# Reshape to row vectors of shape (1, N): each column is one sample, so the
# whole batch goes through each layer with a single matrix product.
x = x.reshape(1, x.size)
y = y.reshape(1, y.size)

# Initialize weights for a 1 -> 3 -> 2 -> 1 fully connected network.
W1 = np.random.randn(3, 1)
b1 = np.random.randn(3, 1)
W2 = np.random.randn(2, 3)
b2 = np.random.randn(2, 1)
W3 = np.random.randn(1, 2)
b3 = np.random.randn(1, 1)

sig1 = Sigmoid()
sig2 = Sigmoid()
lr = 0.001

for i in range(30000):
    # Forward pass: two sigmoid hidden layers, linear output layer.
    a1 = W1 @ x + b1
    c1 = sig1.forward(a1)
    a2 = W2 @ c1 + b2
    c2 = sig2.forward(a2)
    y_pred = W3 @ c2 + b3

    Loss = mean_squared_error(y, y_pred)
    print(f"Loss[{i}]: {Loss}")

    # Backward pass. The loss is 0.5 * sum((y_pred - y)**2), so its gradient
    # with respect to y_pred is simply (y_pred - y).
    dy_pred = y_pred - y
    dc2 = W3.T @ dy_pred
    da2 = sig2.backward(dc2)
    dc1 = W2.T @ da2
    da1 = sig1.backward(dc1)

    # Partial derivatives of the loss w.r.t. each layer's parameters.
    # Bias gradients sum over the sample axis; keepdims keeps the (k, 1)
    # shape so they line up with their bias column vectors.
    dW3 = dy_pred @ c2.T
    db3 = np.sum(dy_pred, axis=1, keepdims=True)
    dW2 = da2 @ c1.T
    db2 = np.sum(da2, axis=1, keepdims=True)
    dW1 = da1 @ x.T
    db1 = np.sum(da1, axis=1, keepdims=True)

    # Plain gradient-descent step, applied in place.
    W3 -= lr * dW3
    b3 -= lr * db3
    W2 -= lr * dW2
    b2 -= lr * db2
    W1 -= lr * dW1
    b1 -= lr * db1

    if i % 100 == 99:
        # Redraw the target curve and the current prediction every 100 steps.
        # (y_pred is from this iteration's forward pass, before the update.)
        plt.cla()
        plt.plot(x.T, y.T)
        plt.plot(x.T, y_pred.T)

# Display the final figure; without this, running the file as a plain
# script never shows the plot.
plt.show()
# 效果 (Result)