目录
一、基础理论
这里只写公式,更加详细的可以看前篇CSDN
前向传递(得到输出y)
(b是偏置)
反向传递(更新权重w)
更新权重:
二、实现多数据分类
1、设置初始参数
# 1、设置初始参数
# 输入 #每一行对应一个标签
x = np.array([[1, 0, 2], #(0,2)坐标
[1, 1, 3], #(1,3)坐标
[1, 4, 2], #(4,2)坐标
[1, 5, 1] #(5,1)坐标
])
# 初始权重(0~1的随机生产数)
w = np.random.random([3, 1]) #3行1列
# 偏置
b = 1
# 标签(正确标签,训练结束的目标)
true = np.array([ #每一行对应一个标签
[-1],
[-1],
[ 1],
[ 1]
])
# 学习率
lr = 1
2、训练
注:dot是矩阵点乘,矩阵点乘要注意前者列与后者行是否相等。
# 开始训练
for i in range(100):
# 2、正向传播:计算输出y
y = np.sign(np.dot(x, w)).astype(int) #dot:矩阵乘法(x列==w行)
print('epoch:', i) #迭代次数
print('weight:', w) #权重
print(y) #标签
# 训练成功
if (y == true).all():
print('训练成功!')
print('y = ', y)
break
# 3、反向传播:更新权重
# 训练失败(更新权重)
else:
w += lr * np.dot(x.T, true-y)/x.shape[0]
# w:权重 lr:学习率 np.dot:矩阵点乘 true-y:差 x.T:x的转置 x.shape[0]:行数
3、画图
3-1、画点
# 4、画图
# 4-1、画点
# 正样本坐标
x1, y1 = [0,1], [2,3] #坐标:(0,2), (1,3)
# 负样本坐标
x2, y2 = [4,5], [2,1] #坐标:(4,2),(5,1)
plt.scatter(x1, y1, c='b') #点:(x1,y1)坐标,blue颜色
plt.scatter(x2, y2, c='g') #点:(x2,y2)坐标,green颜色
3-2、画线段
# 4-2、画线段
# 定义线段两点的x坐标
line_x = (0, 6)
# 计算线性方程的k和d:
# w0*x0+w1*x1+w2*x2 = 0
# 把x1、x2分别看作:x、y
# 可以得到:w0 + w1*x + w2*y = 0 --> y = -w1/w2*x + -w0/w2 --> k=-w1/w2, d=-w0/w2
# 线段两端点的y坐标
k = -w[1]/w[2] #斜率
d = -w[0]/w[2] #截距
line_y = k * line_x + d #y坐标
#画线段(通过两个点)
plt.plot(line_x, line_y, 'r') #r:red
plt.show()
总代码
# 手写单层感知器(多数据分类)
import numpy as np
import matplotlib.pyplot as plt
# 1、设置初始参数
# 输入 #每一行对应一个标签
x = np.array([[1, 0, 2], #(0,2)坐标
[1, 1, 3], #(1,3)坐标
[1, 4, 2], #(4,2)坐标
[1, 5, 1] #(5,1)坐标
])
# 初始权重(0~1的随机生产数)
w = np.random.random([3, 1]) #3行1列
# 偏置
b = 1
# 标签(正确标签,训练结束的目标)
true = np.array([ #每一行对应一个标签
[-1],
[-1],
[ 1],
[ 1]
])
# 学习率
lr = 1
# 开始训练
for i in range(100):
# 2、正向传播:计算输出y
y = np.sign(np.dot(x, w)).astype(int) #dot:矩阵乘法(x列==w行)
print('epoch:', i) #迭代次数
print('weight:', w) #权重
print(y) #标签
# 训练成功
if (y == true).all():
print('训练成功!')
print('y = ', y)
break
# 3、反向传播:更新权重
# 训练失败(更新权重)
else:
w += lr * np.dot(x.T, true-y)/x.shape[0]
# w:权重 lr:学习率 np.dot:矩阵点乘 true-y:差 x.T:x的转置 x.shape[0]:行数
# 4、画图
# 4-1、画点
# 正样本坐标
x1, y1 = [0,1], [2,3] #坐标:(0,2), (1,3)
# 负样本坐标
x2, y2 = [4,5], [2,1] #坐标:(4,2),(5,1)
plt.scatter(x1, y1, c='b') #点:(x1,y1)坐标,blue颜色
plt.scatter(x2, y2, c='g') #点:(x2,y2)坐标,green颜色
# 4-2、画线段
# 定义线段两点的x坐标
line_x = (0, 6)
# 计算线性方程的k和d:
# w0*x0+w1*x1+w2*x2 = 0
# 把x1、x2分别看作:x、y
# 可以得到:w0 + w1*x + w2*y = 0 --> y = -w1/w2*x + -w0/w2 --> k=-w1/w2, d=-w0/w2
# 线段两端点的y坐标
k = -w[1]/w[2] #斜率
d = -w[0]/w[2] #截距
line_y = k * line_x + d #y坐标
#画线段(通过两个点)
plt.plot(line_x, line_y, 'r') #r:red
plt.show()