"""Fit a nonlinear function using 101 RBF basis functions with linear regression."""

import numpy as np
import matplotlib.pyplot as plt


def rbf_tut1_q3(xx, kk, hh):
    """Evaluate Gaussian radial basis function number ``kk`` at ``xx``.

    Basis function ``kk`` is centred at ``(kk - 51) * hh / sqrt(2)``. Index 51
    therefore sits at zero, and consecutive indices are ``hh / sqrt(2)`` apart.
    Its value is ``exp(-(xx - centre)**2 / hh**2)``.

    Args:
        xx: Evaluation point(s); a scalar or array that broadcasts with ``kk``.
        kk: Basis-function index (the calling script uses 1..101).
        hh: Bandwidth of the Gaussian.

    Returns:
        Basis-function values with the broadcast shape of ``xx`` and ``kk``.
    """
    mu = (kk - 51) * hh / np.sqrt(2)  # centre of basis function kk
    sq_dist = (xx - mu) ** 2
    return np.exp(-sq_dist / hh ** 2)


N = 10  # number of training points
N_real = 1000  # number of points on the dense grid used to draw curves
K = 101  # number of RBF basis functions
kks = np.arange(1, K + 1)  # basis-function indices 1..K
hh = 0.2  # RBF bandwidth


def _target(x):
    """Return the function being fitted: x**7 + x**5 + x**3 + x."""
    # NOTE(review): the original script wrote ``x*3`` here (4x in total with the
    # trailing ``+ x``). The descending odd powers indicate ``x**3`` was meant.
    return x**7 + x**5 + x**3 + x


def _design_matrix(x):
    """Return the (len(x), K) matrix whose column j is basis kks[j] evaluated at x."""
    # Broadcasting (n, 1) points against (1, K) indices evaluates every basis
    # function at every point in a single call.
    return rbf_tut1_q3(x[:, None], kks[None, :], hh)


# Training inputs: N evenly spaced points -1 + 2*i/N in [-1, 1).
# linspace guarantees exactly N points; arange with a float step may not.
xx = np.linspace(-1, 1, N, endpoint=False)
yy = _target(xx)

# Dense grid used to draw the true function and the fitted model.
xxx = np.linspace(-1, 1, N_real, endpoint=False)
yyy = _target(xxx)

X = _design_matrix(xx)  # shape (N, K)
print(X.shape)

# N = 10 equations in K = 101 unknowns: the system is underdetermined, and
# lstsq returns the least-squares solution with the smallest weight norm.
W = np.linalg.lstsq(X, yy, rcond=None)[0]
print(np.matmul(X, W).shape)

plt.plot(xxx, yyy, c='r', label='target')
# Evaluate the fit on the dense grid. Plotting X @ W against xx would only
# join the fitted values at the training points and hide how the model
# behaves between them.
plt.plot(xxx, _design_matrix(xxx) @ W, c='y', label='RBF fit')
plt.scatter(xx, yy, marker='o', label='training points')
plt.legend()
plt.show()

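# Source: blog post "Fit a nonlinear function using 101 RBF basis functions with linear regression".

# Previous post (上一篇): "C#的$符号" (the $ symbol in C#)


# Next post (下一篇): "acwing-1088旅行问题" (AcWing problem 1088: travel problem)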