scikit-learn:多项式回归

在一元回归分析中,如果自变量x和因变量y之间的关系是非线性的,在找不到合适的函数曲线来拟合的情况下,可以采用一元多项式回归。如果自变量不止一个,则采用多元多项式回归。
多项式回归可以处理相当一类非线性问题,因为任意函数都可以分段,用多项式来逼近。

使用的假设函数是一元一次方程,也就是二维平面上的一条直线。但是很多时候可能会遇到直线方程无法很好的拟合数据的情况,这个时候可以尝试使用多项式回归。多项式回归中,加入了特征的更高次方(例如平方项或立方项),也相当于增加了模型的*度,用来捕获数据中非线性的变化。添加高阶项的时候,也增加了模型的复杂度。随着模型复杂度的升高,模型的容量以及拟合数据的能力增加,可以进一步降低训练误差,但导致过拟合的风险也随之增加。

scikit-learn:多项式回归

多项式回归的一般形式

scikit-learn:多项式回归
scikit-learn:多项式回归
scikit-learn:多项式回归
损失函数
scikit-learn:多项式回归

1. 多项式回归的实现

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

scikit-learn:多项式回归

1.1 直线方程拟合

data = np.array([[ -2.95507616,  10.94533252],
       [ -0.44226119,   2.96705822],
       [ -2.13294087,   6.57336839],
       [  1.84990823,   5.44244467],
       [  0.35139795,   2.83533936],
       [ -1.77443098,   5.6800407 ],
       [ -1.8657203 ,   6.34470814],
       [  1.61526823,   4.77833358],
       [ -2.38043687,   8.51887713],
       [ -1.40513866,   4.18262786]])
m = data.shape[0]  # 样本大小
X = data[:, 0].reshape(-1, 1)  # 将array转换成矩阵
y = data[:, 1].reshape(-1, 1)
plt.plot(X, y, "b.")
plt.xlabel('X')
plt.ylabel('y')
plt.show()

scikit-learn:多项式回归
#下面先用直线方程拟合上面的数据点:

lin_reg = LinearRegression()
lin_reg.fit(X, y)
print(lin_reg.intercept_, lin_reg.coef_)  # [ 4.97857827] [[-0.92810463]]

X_plot = np.linspace(-3, 3, 1000).reshape(-1, 1)
y_plot = np.dot(X_plot, lin_reg.coef_.T) + lin_reg.intercept_
plt.plot(X_plot, y_plot, 'r-')
plt.plot(X, y, 'b.')
plt.xlabel('X')
plt.ylabel('y')
plt.savefig('regu-2.png', dpi=200)

scikit-learn:多项式回归

1.2 使用多项式方程

为了拟合2次方程,需要有特征x^2的数据,这里可以使用函数"PolynomialFeatures"来获得:

poly_features = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly_features.fit_transform(X)
print(X_poly)
lin_reg = LinearRegression()
lin_reg.fit(X_poly, y)
print(lin_reg.intercept_, lin_reg.coef_)  # [ 2.60996757] [[-0.12759678  0.9144504 ]]

X_plot = np.linspace(-3, 3, 1000).reshape(-1, 1)
X_plot_poly = poly_features.fit_transform(X_plot)
y_plot = np.dot(X_plot_poly, lin_reg.coef_.T) + lin_reg.intercept_
plt.plot(X_plot, y_plot, 'r-')
plt.plot(X, y, 'b.')
plt.show()

scikit-learn:多项式回归
得到了训练后的参数,即多项式方程为h=−0.13x+0.91x^2+2.61

上一篇:洛谷P3177 树上染色


下一篇:全栈项目|小书架|服务器开发-Koa2 参数校验处理