Notes on learning LDA and SVM with scikit-learn
I. The LDA algorithm
1. Import the required libraries
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification  # generate a synthetic classification dataset
2. Generate and split the dataset
# generate the dataset
data,target=make_classification()
print(data)
print(target)
# split into training and test sets at an 8:2 ratio
x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2)
[[ 0.1527 -0.8111 -0.3543 ... -0.3221 -1.1127 -0.0102]
[-0.769 -0.4404 -0.2247 ... -0.0384 -0.3103 -1.0622]
[ 0.7917 -0.8278 1.6116 ... 0.3636 0.2646 1.3723]
...
[-2.3984 1.3397 -1.6188 ... 1.1403 2.0689 -1.2399]
[-0.6706 -0.0643 -0.2727 ... -2.7449 -0.7093 -0.2345]
[ 1.0208 -0.9271 -0.9315 ... 2.1307 0.6394 -1.7078]]
[1 1 0 1 0 1 0 0 1 0 1 1 0 1 0 0 1 0 1 0 0 1 1 0 0 1 1 1 1 1 1 0 1 1 0 0 0
1 1 1 0 0 1 0 1 1 0 0 0 1 0 1 1 1 0 0 0 0 0 1 0 1 0 0 1 0 1 0 1 1 0 1 0 0
0 0 0 0 0 0 1 1 0 0 1 1 1 0 0 0 1 0 1 1 1 1 0 0 1 1]
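Because no random seed is fixed above, `make_classification` draws a different dataset on every run, so the printed arrays and later scores will not be reproducible. A minimal sketch of a seeded variant (the seed value 0 is an arbitrary choice for illustration):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# By default make_classification draws 100 samples with 20 features;
# fixing random_state makes the data and the split repeatable.
data, target = make_classification(n_samples=100, n_features=20, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(
    data, target, test_size=0.2, random_state=0)

print(x_train.shape, x_test.shape)  # (80, 20) (20, 20)
```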
3. Fit and evaluate
lda = LDA()
# fit the model
lda.fit(x_train, y_train)
lda.score(x_test, y_test)
0.85
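Besides classification, LDA is often used for supervised dimensionality reduction. A small illustrative sketch (not part of the run above): with 2 classes, LDA can project onto at most n_classes - 1 = 1 discriminant axis.

```python
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

data, target = make_classification(random_state=0)  # 100 samples, 2 classes by default
lda = LDA(n_components=1)
# project the 20-dimensional data onto the single discriminant axis
projected = lda.fit_transform(data, target)
print(projected.shape)  # (100, 1)
```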
II. The SVM algorithm
1. Import the required libraries
from sklearn.datasets import make_moons  # generate the two-moons dataset
from sklearn.model_selection import train_test_split
from sklearn import svm  # SVM models
import matplotlib.pyplot as plt  # plot the dataset
2. Generate and split the dataset
data, target = make_moons(noise=0.3)
plt.plot(data[:, 0][target == 0], data[:, 1][target == 0], "rs")  # class 0 as red squares
plt.plot(data[:, 0][target == 1], data[:, 1][target == 1], "bs")  # class 1 as blue squares
plt.grid(True)
plt.show()
x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2)
3. Fitting SVMs with different kernels
The relevant class signature is:
class sklearn.svm.SVC(*, C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=- 1, decision_function_shape='ovr', break_ties=False, random_state=None)
Here `kernel` selects the kernel function; the default is 'rbf', the Gaussian (RBF) kernel.
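To make the RBF default concrete, here is a minimal sketch of the kernel value it computes, k(x, y) = exp(-gamma * ||x - y||²), checked against scikit-learn's own `rbf_kernel` helper (the gamma value 0.5 is an arbitrary choice for illustration):

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

x = np.array([[0.0, 0.0]])
y = np.array([[1.0, 2.0]])
gamma = 0.5

# k(x, y) = exp(-gamma * squared Euclidean distance)
manual = np.exp(-gamma * np.sum((x - y) ** 2))
library = rbf_kernel(x, y, gamma=gamma)[0, 0]
print(np.isclose(manual, library))  # True
```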
res_linear = svm.SVC(kernel="linear")  # linear kernel
res_linear.fit(x_train,y_train)
res_linear.score(x_test,y_test)
0.8
res_poly = svm.SVC(kernel="poly", degree=3)  # polynomial kernel (degree 3)
res_poly.fit(x_train,y_train)
res_poly.score(x_test,y_test)
0.65
res_rbf = svm.SVC()  # default RBF (Gaussian) kernel
res_rbf.fit(x_train,y_train)
res_rbf.score(x_test,y_test)
0.95
With the linear kernel the test accuracy is 0.8; with the polynomial kernel, 0.65; with the Gaussian (RBF) kernel, 0.95. Note that for a classifier, `score` returns mean classification accuracy, not R-squared, and since no random seed was fixed these scores will vary between runs.
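The three fits above can be condensed into one loop. A hedged sketch with a fixed seed so the comparison is repeatable (the scores will not match the unseeded runs above exactly):

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn import svm

data, target = make_moons(noise=0.3, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(
    data, target, test_size=0.2, random_state=0)

# fit one SVC per kernel and record its test accuracy
scores = {}
for kernel in ("linear", "poly", "rbf"):
    clf = svm.SVC(kernel=kernel)
    clf.fit(x_train, y_train)
    scores[kernel] = clf.score(x_test, y_test)
print(scores)
```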