sklearn 特征降维利器 —— PCA TSNE

同为降维工具,二者的主要区别在于,

  • 所在的包不同(也即机制和原理不同)
    • from sklearn.decomposition import PCA
    • from sklearn.manifold import TSNE
  • 因为原理不同,导致,tsne 保留下的属性信息,更具代表性,也即最能体现样本间的差异;
  • TSNE 运行极慢,PCA 则相对较快;

因此更为一般的处理,尤其在展示(可视化)高维数据时,常常先用 PCA 进行降维,再使用 tsne:

data_pca = PCA(n_components=50).fit_transform(data)
data_pca_tsne = TSNE(n_components=2).fit_transform(data_pca)
  • 1
  • 2

t-SNE(t-distribution Stochastic Neighbor Embedding)是目前最为流行的高维数据的降维算法。

t-SNE 成立的前提基于这样的一个假设:我们现实世界观察到的数据集,都在本质上有一种低维的特性(low intrinsic dimensionality),尽管它们嵌入在高维空间中,甚至可以说,高维数据经过降维后,在低维状态下,更能显现其本质特性,这其实也是流形学习(Manifold Learning)的基本思想。

原始论文请见,论文链接(pdf)

1. sklearn 仿真

  • import 必要的库;

    import numpy as np
    from numpy import linalg
    from numpy.linalg import norm
    from scipy.spatial.distance import squareform, pdist
    
    # We import sklearn.
    import sklearn
    from sklearn.manifold import TSNE
    from sklearn.datasets import load_digits
    from sklearn.preprocessing import scale
    
    # We'll hack a bit with the t-SNE code in sklearn 0.15.2.
    from sklearn.metrics.pairwise import pairwise_distances
    from sklearn.manifold.t_sne import (_joint_probabilities,
                                        _kl_divergence)
    from sklearn.utils.extmath import _ravel
    # Random state.
    RS = 20150101
    
    # We'll use matplotlib for graphics.
    import matplotlib.pyplot as plt
    import matplotlib.patheffects as PathEffects
    import matplotlib
    %matplotlib inline
    
    # We import seaborn to make nice plots.
    import seaborn as sns
    sns.set_style('darkgrid')
    sns.set_palette('muted')
    sns.set_context("notebook", font_scale=1.5,
                    rc={"lines.linewidth": 2.5})
    
    # We'll generate an animation with matplotlib and moviepy.
    from moviepy.video.io.bindings import mplfig_to_npimage
    import moviepy.editor as mpy
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
  • 加载数据集

    digits = load_digits()
    		# digits.data.shape ⇒ (1797L, 64L)
    
    • 1
    • 2
  • 调用 sklearn 工具箱中的 t-SNE 类

    X = np.vstack([digits.data[digits.target==i]
                   for i in range(10)])
    y = np.hstack([digits.target[digits.target==i]
                   for i in range(10)])
    digits_proj = TSNE(random_state=RS).fit_transform(X)
    		# digits_proj:(1797L, 2L),ndarray 类型
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
  • 可视化

    def scatter(x, colors):
        # We choose a color palette with seaborn.
        palette = np.array(sns.color_palette("hls", 10))
    
        # We create a scatter plot.
        f = plt.figure(figsize=(8, 8))
        ax = plt.subplot(aspect='equal')
        sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40,
                        c=palette[colors.astype(np.int)])
        plt.xlim(-25, 25)
        plt.ylim(-25, 25)
        ax.axis('off')
        ax.axis('tight')
    
        # We add the labels for each digit.
        txts = []
        for i in range(10):
            # Position of each label.
            xtext, ytext = np.median(x[colors == i, :], axis=0)
            txt = ax.text(xtext, ytext, str(i), fontsize=24)
            txt.set_path_effects([
                PathEffects.Stroke(linewidth=5, foreground="w"),
                PathEffects.Normal()])
            txts.append(txt)
    
        return f, ax, sc, txts
    scatter(digits_proj, y)
    plt.savefig('images/digits_tsne-generated.png', dpi=120)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

An illustrated introduction to the t-SNE algorithm

再分享一下我老师大神的人工智能教程吧。零基础!通俗易懂!风趣幽默!还带黄段子!希望你也加入到我们人工智能的队伍中来!http://www.captainbed.net

上一篇:TSNE/分析两个数据的分布


下一篇:m4