Spark ML源码分析之二 从单机到分布式

        前一节从宏观角度给大家介绍了Spark ML的设计框架(链接:http://www.cnblogs.com/jicanghai/p/8570805.html),本节我们将介绍,Spark ML中,机器学习问题从单机到分布式转换的核心方法。
        单机时代,如果我们想解决一个机器学习的优化问题,最重要的就是根据训练数据,计算损失函数和梯度。由于是单机环境,什么都好说,只要公式推导没错,浮点数计算溢出问题解决好,就好了。但是,当我们的训练数据量足够大,大到单机根本存储不下的时候,对分布式学习的需求就出现了。比如电商数据,动辄上亿的训练数据量,单机望尘莫及,只能求助于分布式计算。
        那么问题来了,在分布式计算中,怎样计算得到损失函数的值,以及它的梯度值呢?这就涉及到Spark ml的一个核心,用八个字概括就是,模型集中,计算分布。具体来说,比如我们要学习一个逻辑回归模型,它的训练数据可能是存储在成百上千台服务器上,但具体的模型,只集中于一台服务器上。每次迭代时,我们现在训练数据所在的服务器上,并行的计算出,每个服务器包含的训练数据,所对应的损失函数值和梯度值,然后把这些信息集中在模型所在的机器上,进行合并,总结出所有训练数据的损失函数值和梯度值,然后对所学习的参数进行迭代,并把参数分发给拥有训练数据的服务器,并进入下一个迭代循环,直到模型收敛。
        如此看来,分布式机器学习也没有什么特别的,核心问题就在于,怎样把每个服务器上计算的损失函数值和地图值集中到模型所在的服务器上,除此之外,跟单机的机器学习问题并没有什么不同。
        这一步,在Spark ML中是如何实现的呢?这里要隆重介绍一个函数,treeAggregate,在我看来,这个函数是从单机到分布式机器学习的核心,理解了这个函数,分布式机器学习问题,就理解大半了。
        treeAggregate函数主要做什么呢?它负责把每一台服务器上的信息进行聚合,然后汇总给模型所在的服务器。拥有训练数据的服务器,可能动辄成千上万,这么多数据怎样聚合起来呢?其实函数名字已经有暗示了,它用的是树形聚合方法。假设我们有32台服务器,如果使用线性聚合,也就是说,1跟2合并,结果再跟3合并,这样一共需要进行31次合并,而且每次合并还不能并行进行,因此treeAggregate采用的方法是,把32个节点分配到一颗二叉树的32个叶子节点,然后从叶子节点开始一层一层的聚合,这样只需要5次聚合就可以了。
        具体的,使用treeAggregate函数需要定义两种运算,分别是seqOp和combOp,前者的作用是,把一个训练样本加入已有的统计,即对损失函数值和梯度进行更新,后者的作用是,把两个统计信息合并起来,可以这样理解,前者主要在单机上的统计计算时起作用,后者主要是在不同服务器进行数据合并时起作用。
        有了这些核心概念,就可以进入optim目录去一探究竟了,optim目录是Spark ML跟优化相关内容的代码库,它主要包含三部分,一是aggregator目录,二是loss目录,三是根目录,下面我们逐一介绍。
        aggregator目录下存放的是,聚合相关的代码。我们知道在机器学习任务中,不同的任务需要聚合的信息是不一样的。这里就为我们实现了几个最基本的聚合操作。其中,DifferentiableLossAggregator是基类,顾名思义,实现了最基本的可微损失函数的聚合,实际上的聚合操作都是由它的子类完成的,基类中定义了通用的merge操作,具体的add操作由各子类自己定义,代码实现都比较直接,就不一一介绍了,感兴趣的朋友可以直接读源码。
        loss目录下存放的是,损失函数相关的代码。其实,最一般性的损失函数是在breeze库中定义的,这个等我们在介绍breeze库的时候再细说。loss目录下有两个文件,一个是DifferentiableRegularization.scala,这里是把正则也当作一种损失,主要包含L2正则,另一个是RDDLossFunction.scala,这个就非常重要了,它就是应用treeAggregate函数,从单机的损失+梯度,汇总到分布式版的损失+梯度的函数,它主要应用了aggregate目录下的聚合类实现分布式的聚合运算。
        根目录下主要包含了几个优化问题的解法,最基础的是NormalEquationSolver.scala,它主要描述了一个最小二乘的标准解法,也就是正规方程的解法,其次是WeightedLeastSquares.scala,它解决了一个带权值的最小二乘问题,利用了正规方程解法,最后是IterativelyReweightedLeastSquares.scala,这是在解逻辑斯蒂回归等一大类一般性线性回归问题中常用的IRLS算法,利用了带权值的最小二乘解法。
        好,今天的介绍就到这里了。作者也是初学者,欢迎大家批评指正。
上一篇:Spark ML机器学习库评估指标示例


下一篇:Spark ML源码分析之四 树