SparkML Clustering: K-Means, GMM, and LDA

The Concept of Clustering

Clustering takes a large, unlabeled (unsupervised) dataset of N objects and, based on the similarity between data points, partitions it into K groups (K clusters), so that similarity within a cluster is high while similarity between clusters is low. User profiling is a common application: users are grouped into segments based on behavioral features or metadata.

Common Clustering Algorithms and How They Work

The K-means Algorithm

Also known as k-means clustering, it is the most widely used clustering algorithm and the foundation for many others. Here is how it works:

Since the data is to be divided into k clusters, the algorithm first picks k objects at random, each initially representing one cluster center. Every other object computes its distance (here, its similarity) to these k centers and is assigned to the nearest one. After the first round of clustering, the mean of each cluster is recomputed, and all N objects are reassigned to the nearest of the new centers. This process repeats, each round improving on the previous partition, until the criterion function converges (the cluster assignments no longer change).

For example, suppose I want to split 10 boys into 2 groups. I randomly pick two boys, a and b. For c, I compute his similarity to a and to b; if he is most similar to a, he goes into a's group. After the first round I get group 1 {a, c, e, f, i} and group 2 {b, d, g, h, j}. I then recompute the means of the two groups, say 1.5 and 5.1, and reassign all 10 boys to whichever of 1.5 and 5.1 they are closer to. Each iteration makes the clustering more accurate, and I repeat until an iteration produces no further change.
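The loop described above can be sketched in a few lines of plain Scala (no Spark). This is an illustrative 1-D version on made-up data, not the SparkML implementation: assign each point to its nearest center, recompute each center as the mean of its cluster, and repeat until nothing moves:

```scala
object KMeansSketch {
  // One Lloyd iteration: assign each point to its nearest center,
  // then recompute each center as the mean of its assigned points.
  def step(points: Seq[Double], centers: Seq[Double]): Seq[Double] = {
    val assigned = points.groupBy(p => centers.minBy(c => math.abs(p - c)))
    // a center with no assigned points keeps its old position
    centers.map(c => assigned.get(c).map(ps => ps.sum / ps.size).getOrElse(c))
  }

  def main(args: Array[String]): Unit = {
    val pts = Seq(0.0, 0.1, 0.2, 9.0, 9.1, 9.2)
    // start from two arbitrary initial centers and iterate to convergence
    var centers: Seq[Double] = Seq(0.0, 9.0)
    var prev: Seq[Double] = Nil
    while (centers != prev) { prev = centers; centers = step(pts, centers) }
    println(centers)  // settles on the means of the two obvious clusters
  }
}
```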

K-means has some drawbacks, the first being sensitivity to the initial centers. The figure below illustrates this well: to the human eye, splitting into 4 clusters should clearly give the partition on the left, but K-means can end up with the partition on the right, which is obviously wrong. This is what makes the result so dependent on the initial values.

[Figure: the same four-cluster dataset, with the correct partition on the left and a poor K-means result caused by bad initial centers on the right]

This weakness is what motivated Bisecting k-means.

Bisecting k-means

Bisecting k-means was designed to reduce the instability in clustering results caused by k-means' random choice of initial centers; it is far less sensitive to that randomness.
It starts with all points in a single cluster and splits that cluster in two. It then repeatedly selects one cluster (the one with the largest SSE) and bisects it, choosing among the candidate 2-way splits the one whose two child clusters have the smallest total SSE. This ensures that each bisection produces two reasonably good (possibly optimal) clusters.
SSE (Sum of Squared Errors) is the sum of the squared errors between the fitted values and the corresponding original data points; for clustering, it is the sum of squared distances from each point to its assigned center, and it serves as a measure of clustering quality. The closer the SSE is to 0, the better the model fits the data.
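As a small illustration (plain Scala, 1-D points, made-up data rather than anything from SparkML), here is the SSE of a cluster and the "split the largest-SSE cluster next" rule that bisecting k-means uses:

```scala
object SseSketch {
  // SSE of one cluster: squared distance of each point to the cluster mean
  def sse(ps: Seq[Double]): Double = {
    val m = ps.sum / ps.size
    ps.map(p => (p - m) * (p - m)).sum
  }

  def main(args: Array[String]): Unit = {
    // two current clusters; bisecting k-means splits the one with the largest SSE next
    val clusters = Seq(Seq(0.0, 0.1, 0.2), Seq(4.0, 5.0, 9.0, 9.1))
    clusters.foreach(c => println(f"SSE=${sse(c)}%.3f for $c"))
    val toSplit = clusters.maxBy(sse)
    println(s"split next: $toSplit")  // the spread-out cluster has the larger SSE
  }
}
```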

Everything above is hard clustering: each object belongs to exactly one cluster. If I have two clusters A and B, every object is either in A or in B; there is no third possibility. Soft clustering instead uses probabilities: an object can belong 60% to A and 40% to B, i.e. rather than belonging entirely to one distribution, it belongs to different distributions with different probabilities. GMM (the Gaussian Mixture Model) is a soft clustering method.

GMM (Gaussian Mixture Model)

The difference from K-Means is that K-Means computes which single cluster each data point belongs to, while GMM computes the probability of each point belonging to each component.
The GMM algorithm proceeds as follows:
1. Assume there are K classes, i.e. K Gaussian distributions.
2. Give each Gaussian an initial mean μ and covariance Σ.
3. For each sample, compute its probability under each Gaussian.
4. A sample's contribution to a Gaussian is given by its probability under it; using these contributions as weights, compute a weighted mean and covariance to replace that Gaussian's previous mean and covariance.
5. Repeat steps 3 and 4 until every Gaussian's mean and covariance converge.
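Steps 3 and 4 are the E-step and M-step of the EM algorithm. The loop can be sketched in plain Scala for 1-D toy data, with the mixing weights held fixed at uniform for brevity (a full GMM also updates them); this is an illustration, not the SparkML implementation:

```scala
object GmmSketch {
  // 1-D Gaussian probability density
  def pdf(x: Double, mu: Double, variance: Double): Double =
    math.exp(-(x - mu) * (x - mu) / (2 * variance)) / math.sqrt(2 * math.Pi * variance)

  // One EM iteration over K components with uniform mixing weights.
  def emStep(xs: Seq[Double], mus: Seq[Double], variances: Seq[Double]): (Seq[Double], Seq[Double]) = {
    // E-step: responsibility of each component for each point (step 3 above)
    val resp = xs.map { x =>
      val p = mus.zip(variances).map { case (m, v) => pdf(x, m, v) }
      val s = p.sum
      p.map(_ / s)
    }
    // M-step: responsibility-weighted mean and variance replace the old ones (step 4)
    val newMus = mus.indices.map { k =>
      val w = resp.map(_(k))
      xs.zip(w).map { case (x, r) => x * r }.sum / w.sum
    }
    val newVars = mus.indices.map { k =>
      val w = resp.map(_(k))
      xs.zip(w).map { case (x, r) => r * (x - newMus(k)) * (x - newMus(k)) }.sum / w.sum
    }
    (newMus, newVars)
  }

  def main(args: Array[String]): Unit = {
    val xs = Seq(0.0, 0.1, 0.2, 9.0, 9.1, 9.2)
    var (mus, vs) = (Seq(0.0, 9.0), Seq(1.0, 1.0))
    // repeat steps 3-4 until the parameters stop moving (step 5); a fixed
    // number of iterations is enough for this tiny example
    for (_ <- 1 to 20) { val r = emStep(xs, mus, vs); mus = r._1; vs = r._2 }
    println(mus)  // the means drift toward the two cluster centers
  }
}
```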

Clustering in SparkML

SparkML provides the following clustering algorithms:

  • K-means
  • Latent Dirichlet allocation (LDA)
  • Bisecting k-means
  • Gaussian Mixture Model (GMM)

KMeans

package ml.test

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession

/**
  * Created by liuyanling on 2018/3/24
  */
object KMeansDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("kmeans_data.txt")
    // setK sets the number of clusters, setSeed the random seed
    val kmeans = new KMeans().setK(3).setSeed(1L)
    // fit the clustering model
    val model = kmeans.fit(df)
    // predict, i.e. assign each point to a cluster center
    model.transform(df).show(false)
    // cluster centers
    model.clusterCenters.foreach(println)
    // SSE (sum of squared errors)
    println("SSE:" + model.computeCost(df))
  }
}

The output is:

+-----+-------------------------+----------+
|label|features                 |prediction|
+-----+-------------------------+----------+
|0.0  |(3,[],[])                |1         |
|1.0  |(3,[0,1,2],[0.1,0.1,0.1])|1         |
|2.0  |(3,[0,1,2],[0.2,0.2,0.2])|2         |
|3.0  |(3,[0,1,2],[9.0,9.0,9.0])|0         |
|4.0  |(3,[0,1,2],[9.1,9.1,9.1])|0         |
|5.0  |(3,[0,1,2],[9.2,9.2,9.2])|0         |
+-----+-------------------------+----------+

[9.1,9.1,9.1]
[0.05,0.05,0.05]
[0.2,0.2,0.2]

SSE:0.07499999999994544

Appendix: kmeans_data.txt

0 1:0.0 2:0.0 3:0.0
1 1:0.1 2:0.1 3:0.1
2 1:0.2 2:0.2 3:0.2
3 1:9.0 2:9.0 3:9.0
4 1:9.1 2:9.1 3:9.1
5 1:9.2 2:9.2 3:9.2

BisectingKMeans (Bisecting K-Means)

import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.sql.SparkSession

object BisectingKMeansDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("kmeans_data.txt")
    val bkm = new BisectingKMeans().setK(3).setSeed(1)
    // fit the clustering model
    val model = bkm.fit(df)
    // predict, i.e. assign each point to a cluster center
    model.transform(df).show(false)
    // cluster centers
    model.clusterCenters.foreach(println)
  }
}

The output is:

+-----+-------------------------+----------+
|label|features                 |prediction|
+-----+-------------------------+----------+
|0.0  |(3,[],[])                |0         |
|1.0  |(3,[0,1,2],[0.1,0.1,0.1])|0         |
|2.0  |(3,[0,1,2],[0.2,0.2,0.2])|0         |
|3.0  |(3,[0,1,2],[9.0,9.0,9.0])|1         |
|4.0  |(3,[0,1,2],[9.1,9.1,9.1])|1         |
|5.0  |(3,[0,1,2],[9.2,9.2,9.2])|2         |
+-----+-------------------------+----------+

[0.1,0.1,0.1]
[9.05,9.05,9.05]
[9.2,9.2,9.2]

Note that KMeans and BisectingKMeans generally produce different clusterings.

GaussianMixture (GMM)

import org.apache.spark.ml.clustering.GaussianMixture
import org.apache.spark.sql.SparkSession

object GaussianMixtureDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("kmeans_data.txt")
    val gmm = new GaussianMixture().setK(3).setSeed(0)
    val model = gmm.fit(df)
    for (i <- 0 until model.getK) {
      // weight is the mixing weight of the component,
      // mean its mean vector, cov its covariance matrix
      println(s"Gaussian $i:\nweight=${model.weights(i)}\n" +
        s"mu=${model.gaussians(i).mean}\nsigma=\n${model.gaussians(i).cov}\n")
    }
    // alternatively: model.gaussiansDF.show(false)
  }
}

The output is:

Gaussian 0:
weight=0.28757327891867634
mu=[0.09999633894391767,0.09999633894391767,0.09999633894391767]
sigma=
0.006666589009587926  0.006666589009587926  0.006666589009587926  
0.006666589009587926  0.006666589009587926  0.006666589009587926  
0.006666589009587926  0.006666589009587926  0.006666589009587926  

Gaussian 1:
weight=0.2124267210813245
mu=[0.1000049561651758,0.1000049561651758,0.1000049561651758]
sigma=
0.006666771753107945  0.006666771753107945  0.006666771753107945  
0.006666771753107945  0.006666771753107945  0.006666771753107945  
0.006666771753107945  0.006666771753107945  0.006666771753107945  

Gaussian 2:
weight=0.49999999999999917
mu=[9.099999999999984,9.099999999999984,9.099999999999984]
sigma=
0.006666666666831146  0.006666666666831146  0.006666666666831146  
0.006666666666831146  0.006666666666831146  0.006666666666831146  
0.006666666666831146  0.006666666666831146  0.006666666666831146  

LDA (Topic Model)

LDA is a three-level Bayesian probabilistic model, with word, topic, and document layers.
Viewed generatively, LDA can produce a document: each word is generated by "choosing a topic with some probability, then choosing a word from that topic with some probability", and repeating this yields a whole document. Run in the other direction, LDA is an unsupervised machine learning technique that discovers the topics in a large document collection or corpus.
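The generative process can be sketched directly in plain Scala, with a made-up 4-word vocabulary and two toy topic-word distributions (all of these numbers are assumptions for illustration): each word is produced by first sampling a topic from the document's topic distribution, then sampling a word from that topic's word distribution:

```scala
object LdaGenerativeSketch {
  val rng = new scala.util.Random(7)
  val vocab = Vector("spark", "cluster", "goal", "match")
  // toy topics: topic 0 favors "spark"/"cluster", topic 1 favors "goal"/"match"
  val topicWord = Vector(Vector(0.5, 0.4, 0.05, 0.05), Vector(0.05, 0.05, 0.5, 0.4))

  // draw an index from a discrete probability distribution
  def sample(probs: Seq[Double]): Int = {
    val r = rng.nextDouble(); var acc = 0.0
    probs.indices.find { i => acc += probs(i); r < acc }.getOrElse(probs.size - 1)
  }

  // generate one word: pick a topic with some probability,
  // then pick a word from that topic with some probability
  def genWord(docTopics: Seq[Double]): String = vocab(sample(topicWord(sample(docTopics))))

  def main(args: Array[String]): Unit = {
    val docTopics = Vector(0.9, 0.1)  // a document mostly about topic 0
    println((1 to 8).map(_ => genWord(docTopics)).mkString(" "))
  }
}
```

With the document weighted 90% toward topic 0, most generated words come from topic 0's high-probability terms; LDA inference works backwards from observed words to recover distributions like these.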

import org.apache.spark.ml.clustering.LDA
import org.apache.spark.sql.SparkSession

object LDADemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.read.format("libsvm").load("lda_data.txt")
    // train an LDA model that extracts 10 topics from this corpus
    val lda = new LDA().setK(10).setMaxIter(10)
    val model = lda.fit(df)
    // lower bound on the log likelihood of the entire corpus
    val ll = model.logLikelihood(df)
    // upper bound on perplexity, used to judge how well the topic model fits;
    // tracking how perplexity changes with the topic count helps choose a good K (lower is better)
    val lp = model.logPerplexity(df)
    println(s"Likelihood: $ll")
    println(s"Perplexity: $lp")
    // describe the topics; the argument 4 requests the top 4 terms per topic,
    // sorted by descending weight, i.e. the words contributing most to each topic
    val topics = model.describeTopics(4)
    println("Top 4 terms per topic with their weights (contribution to the topic):")
    topics.show(false)
    // show the per-document topic distribution
    val transformed = model.transform(df)
    transformed.show(false)
  }
}

The output:

Likelihood: -841.0539703557463
Perplexity: 3.235555050621537
Top 4 terms per topic with their weights (contribution to the topic):
+-----+-------------+------------------------------------------------------------------------------------+
|topic|termIndices  |termWeights                                                                         |
+-----+-------------+------------------------------------------------------------------------------------+
|0    |[2, 5, 7, 9] |[0.10606440859601529, 0.10570106168080187, 0.10430389617431601, 0.09677466095389772]|
|1    |[1, 6, 2, 5] |[0.10185076997491191, 0.09816928141878781, 0.09632454354036399, 0.09533709162736335]|
|2    |[10, 6, 9, 1]|[0.21830191650133673, 0.13864436129454022, 0.130631061595757, 0.12280252973166123]  |
|3    |[0, 4, 8, 5] |[0.10270701955806716, 0.098428481533562, 0.09815661242071609, 0.09625859107744991]  |
|4    |[9, 6, 4, 0] |[0.10452964428601186, 0.10414908178146716, 0.10103987045642693, 0.09653933325158909]|
|5    |[1, 10, 0, 6]|[0.10214945376665654, 0.10129060012341293, 0.09513643667808531, 0.09484723303591766]|
|6    |[3, 7, 4, 5] |[0.11638316687887292, 0.09901763170594163, 0.09795372072037434, 0.09538797685003378]|
|7    |[4, 0, 2, 7] |[0.1085545365386738, 0.10334275138802261, 0.10034943368678806, 0.09586142922666488] |
|8    |[0, 7, 8, 9] |[0.11008008210214115, 0.09919723498734867, 0.09810902425212233, 0.09598429155133426]|
|9    |[9, 6, 8, 7] |[0.10106110089499898, 0.10013295826865794, 0.09769277851352344, 0.09637374368101154]|
+-----+-------------+------------------------------------------------------------------------------------+

+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|features                                                       |topicDistribution                                                                                                                                                                                                       |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.7020492743453446,0.004825993770645003,0.2593430279085318,0.004825963614085213,0.0048259356594697244,0.004825986805023069,0.004825959562158929,0.004826029371171158,0.004825898616321139,0.004825930347249381]        |
|1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |[0.008051227952034095,0.00805068078020596,0.9275409153537755,0.008051404074962873,0.008050494944341442,0.008050970211048909,0.008051192103744834,0.008051174355007377,0.008051123052877147,0.008050817172001883]        |
|2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |[0.004196160998853775,0.004196351498154125,0.9622355518772447,0.0041960987674483745,0.004195948652190585,0.004195985718066058,0.004196035128639296,0.0041960011176528375,0.00419585892325802,0.004196007318492326]      |
|3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |[0.0037108249083737366,0.003710870004756753,0.9666019491390152,0.0037108898639785916,0.003710932316515782,0.0037109398110594777,0.003710889692814606,0.0037108416716129995,0.0037109726803076204,0.0037108899115651296] |
|4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |[0.0040206150698117475,0.004020657148286404,0.9638139184735203,0.004020683139979867,0.004020734091778302,0.004020674583024828,0.004020782022237563,0.0040206568444815,0.0040206713675701175,0.00402060725930927]        |
|5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |[0.0037112930006027146,0.0037112713017637172,0.3775560310450831,0.5927537995162975,0.003711281238052303,0.003711254221319455,0.0037112951793086854,0.0037112986549696415,0.0037112620387514264,0.003711213803851488]    |
|6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |[0.0038593994507872802,0.0038594504955414776,0.9652648960492173,0.0038594706691946765,0.0038594932727853293,0.0038594785604599284,0.0038594622091066354,0.003859418239947566,0.0038594708043909586,0.003859460248568964]|
|7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.0043866961466148225,0.004386703747213623,0.9605195811989302,0.004386713724658389,0.004386726334830402,0.004386684736152265,0.00438695465308044,0.004386652225319148,0.0043866443293358506,0.0043866429038647205]     |
|8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |[0.00438677418165635,0.004386828345825393,0.004928678499127466,0.9599768874208335,0.004386777636200246,0.0043868515295607344,0.004386859902729271,0.004386803952404989,0.004386800659948354,0.004386737871713616]       |
|9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |[0.00332681244951176,0.003326822533688302,0.9700586410381743,0.0033268123008501223,0.0033268596356941897,0.003326828624565359,0.0033267707086431955,0.003326842280807512,0.0033268054589628416,0.003326804969102545]    |
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |[0.00419586618191812,0.004195849140887235,0.9622374586787058,0.0041958390363593,0.00419580989169427,0.0041958000694942875,0.004196080417873127,0.004195750742202915,0.004195747737696082,0.004195798103168887]          |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |[0.004826026955676383,0.0048259715597386765,0.005421055835370809,0.9559710760555791,0.004825952242828419,0.004825974033041617,0.0048259715984279505,0.004826101409594881,0.004825983107633105,0.004825887202109255]     |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Appendix: lda_data.txt

0 1:1 2:2 3:6 4:0 5:2 6:3 7:1 8:1 9:0 10:0 11:3
1 1:1 2:3 3:0 4:1 5:3 6:0 7:0 8:2 9:0 10:0 11:1
2 1:1 2:4 3:1 4:0 5:0 6:4 7:9 8:0 9:1 10:2 11:0
3 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:3 11:9
4 1:3 2:1 3:1 4:9 5:3 6:0 7:2 8:0 9:0 10:1 11:3
5 1:4 2:2 3:0 4:3 5:4 6:5 7:1 8:1 9:1 10:4 11:0
6 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:2 11:9
7 1:1 2:1 3:1 4:9 5:2 6:1 7:2 8:0 9:0 10:1 11:3
8 1:4 2:4 3:0 4:3 5:4 6:2 7:1 8:3 9:0 10:0 11:0
9 1:2 2:8 3:2 4:0 5:3 6:0 7:2 8:0 9:2 10:7 11:2
10 1:1 2:1 3:1 4:9 5:0 6:2 7:2 8:0 9:0 10:3 11:3
11 1:4 2:1 3:0 4:0 5:4 6:5 7:1 8:3 9:0 10:1 11:0