第二次作业 9.3

第二次作业 9.3

import numpy as np
import matplotlib.pyplot as plt
theta0 = np.array([th2, th3])
theta1 = np.random.random()
theta2 = np.random.random()
theta3 = np.random.random()
alpha = 0.0000001
x = np.array([[2104, 3], [1600, 3], [2400, 3], [1416, 2], [3000, 4]])
t = np.array([400, 330, 369, 232, 540])

eps = 1e-4
e0 = 150
e1 = 150
e2 = 150
i = 0
while e0 >= eps or e1 >= eps or e2 >= eps:
    e0 = (np.sum((x[i, 0:1] * th0[i])) + th1 * 1 - t[i])
    e1 = e0 * x[i, 0]
    e2 = e0 * x[i, 1]
    th1 = theta1 - alpha * theta0
    th2 = th2 - a * e1
    th3 = th3 - a * e2
    i +=1
print(th1, th2, th3)

0.7560097883036987 0.26594089210910343 0.07149837926879081

上一篇:Eclipse没有权限操作rt.jar包中的sun包,导致sun.net.ftp.FtpClient引用报错


下一篇:监督学习算法