线性判别准则与线性分类编程实践

一、线性判别分析介绍

线性判别分析(Linear Discriminant Analysis,简称 L D A LDALDA)是一种经典的线性学习方法,亦称"Fisher 判别分析"。
线性判别准则与线性分类编程实践

线性判别分析思想:给定训练样本集,设法将样例投影到一条直线上。使得同类样例的投影点尽可能接近、异类样例的投影点尽可能远;在对新样本进行分类时,将其投影到该直线上,再根据投影点的位置来确定新样本的类别。 

二、线性判别分析原理

线性判别准则与线性分类编程实践

1. 类内散度矩阵(within-class scatter matrix)

类内散度矩阵用来判断同类样例的投影点之间的距离。

线性判别准则与线性分类编程实践

 

2. 类间散度矩阵(between-class scatter matrix)

  类间散度矩阵用来判断异类样例的投影点之间的距离。

线性判别准则与线性分类编程实践

3. 广义瑞利商(generalized Rayleigh quotiet)

广义瑞利商(generalized Rayleigh quotiet)就是 L D A LDALDA欲最大化的目标,使用类内散度矩阵和类间散度矩阵将最大化目标改写为:线性判别准则与线性分类编程实践

LDA可从贝叶斯决策理论的角度来阐释,并可证明,当两类数据同先验、满足高斯分布且协方差相等时,LDA可达到最优分类。
 

三、sklearn库实现线性判别分析LDA

  1. 数据生成
    #生成200个三个维度样本
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from sklearn.datasets import make_classification
    x, y = make_classification(n_samples=200, n_features=2, n_redundant=0, n_classes=2, n_informative=2,n_clusters_per_class=2,class_sep =1, random_state =0)
    fig = plt.figure()
    plt.scatter(x[:, 0], x[:, 1], c=y)
    

    线性判别准则与线性分类编程实践

  2. 数据处理
    #设置分类平滑度
    h = .01
    #设置X和Y的边界值
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    
    #使用meshgrid函数返回X和Y两个坐标向量矩阵
    xx, yy = np.meshgrid(np.arange(x_min, x_max,h), np.arange(y_min, y_max,h))
    Z = lda.predict(np.c_[xx.ravel(), yy.ravel()])
    

  3. 数据集划分
    from sklearn.model_selection import train_test_split
    x_train,x_test,y_train,y_test = train_test_split(x, y, random_state=33, test_size=0.25)
    

  4. LDA分类
    #使用LDA进行降维
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 
    from sklearn.linear_model import LogisticRegression
    lda = LinearDiscriminantAnalysis(n_components=1)
    
    x_train_lda = lda.fit_transform(x_train, y_train)  # LDA是有监督方法,需要用到标签
    x_test_lda = lda.fit_transform(x_test, y_test)   # 预测时候特征向量正负问题,乘-1反转镜像
    

  5. 绘制训练集分类图像
    #设置colormap颜色
    cm_bright = ListedColormap(['#D9E021', '#0D8ECF'])
    #绘制数据点
    plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=cm_bright)
    plt.title('Linear Discriminant Analysis Classifiers')
    plt.axis('tight')
    plt.show()
    

    线性判别准则与线性分类编程实践

  6. 绘制测试集分类图
    plt.title('Linear Discriminant Analysis Classifiers')
    plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, cmap=cm_bright)
    plt.show()
    

    线性判别准则与线性分类编程实践

     

四、总结

   LDA算法既可以用来降维,也可以用来分类,但是目前来说,主要还是用于降维,和PCA类似,LDA降维基本也不用调参,只需要指定降维到的维数即可。

五、参考

【机器学习】机器学习之线性判别分析(LDA)_YangMax1的博客-CSDN博客

上一篇:【线代&NumPy】第十二章 - 矩阵对角化课后练习 | 特征值分解 | 散布矩阵 | 降维方法LDA | 简述并提供代码


下一篇:2021-10-20 LDA降维python代码应用详解