概要
ODPS GRAPH是一套面向迭代的图计算处理框架。图计算作业使用图进行建模,图由点(Vertex)和边(Edge)组成,点和边包含权值(Value),ODPS GRAPH支持下述图编辑操作:
- 修改点或边的权值;
- 增加/删除点;
- 增加/删除边;
备注:
- 编辑点和边时,点与边的关系需要用户维护。
通过迭代对图进行编辑、演化,最终求解出结果,典型应用:PageRank,单源最短距离算法 ,K-均值聚类算法 等等。用户可以使用 ODPS GRAPH 提供的接口Java SDK编写图计算程序。
[]()Graph数据结构
ODPS GRAPH能够处理的图必须是是一个由点(Vertex)和边(Edge)组成的有向图。由于ODPS仅提供二维表的存储结构,因此需要用户自行将图数据分解为二维表格式存储在ODPS中,在进行图计算分析时,使用自定义的GraphLoader将二维表数据转换为ODPS Graph引擎中的点和边。至于如何将图数据分解为二维表格式,用户可以根据各自的业务场景做决定。在 示例程序 中,我们给出的示例分别使用不同的表格式来表达图的数据结构,仅供大家参考。
点的结构可以简单表示为 < ID, Value, Halted, Edges >,分别表示点标识符(ID),权值(Value),状态(Halted, 表示是否要停止迭代),出边集合(Edges,以该点为起始点的所有边列表)。边的结构可以简单表示为,分别表示目标点(DestVertexID)和权值(Value)。
例如,上图由下面的点组成:
Vertex | |
---|---|
v0 | <0, 0, false, [ <1, 5 >, <2, 10 > ] > |
v1 | <1, 5, false, [ <2, 3>, <3, 2>, <5, 9>]> |
v2 | <2, 8, false, [<1, 2>, <5, 1 >]> |
v3 | <3, Long.MAX_VALUE, false, [<0, 7>, <5, 6>]> |
v5 | <5, Long.MAX_VALUE, false, [<3, 4 > ]> |
[]()Graph 程序逻辑
[]()1. 加载图:
图加载:框架调用用户自定义的GraphLoader将输入表的记录解析为点或边;分布式化:框架调用用户自定义的Partitioner对点进行分片(默认分片逻辑:点ID哈希值然后对Worker数取模),分配到相应的Worker;
例如,上图假设Worker数是2,那么v0, v2会被分配到Worker0,因为ID对2取模结果为0,而v1, v3, v5将被分配到Worker1,ID对2取模结果为1;
[]()2. 迭代计算:
- 一次迭代为一个”超步”(SuperStep),遍历所有非结束状态(Halted值为false)的点或者收到消息的点(处于结束状态的点收到信息会被自动唤醒),并调用其compute(ComputeContext context, Iterable messages)方法;
-
在用户实现的compute(ComputeContext context, Iterable messages)方法中:
[]()3. 迭代终止(满足以下任意一条):
- 所有点处于结束状态(Halted值为true)且没有新消息产生;
- 达到最大迭代次数;
- 某个Aggregator的terminate方法返回true;
伪代码描述如下:
// 1. load
for each record in input_table {
GraphLoader.load();
}
// 2. setup
WorkerComputer.setup();
for each aggr in aggregators {
aggr.createStartupValue();
}
for each v in vertices {
v.setup();
}
// 3. superstep
for (step = 0; step < max; step ++) {
for each aggr in aggregators {
aggr.createInitialValue();
}
for each v in vertices {
v.compute();
}
}
// 4. cleanup
for each v in vertices {
v.cleanup();
}
WorkerComputer.cleanup();
Aggregator
Aggregator是ODPS-GRAPH作业中常用的feature之一,特别是解决机器学习问题时。ODPS-GRAPH中Aggregator用于汇总并处理全局信息。本文将详细介绍的Aggregator的执行机制、相关API,并以Kmeans Clustering为例子说明Aggregator的具体用法。
Aggregator机制
Aggregator的逻辑分两部分,一部分在所有Worker上执行,即分布式执行,另一部分只在AggregatorOwner所在Worker上执行,即单点。其中在所有Worker上执行的操作包括创建初始值及局部聚合,然后将局部聚合结果发送给AggregatorOwner所在Worker上。AggregatorOwner所在Worker上聚合普通Worker发送过来的局部聚合对象,得到全局聚合结果,然后判断迭代是否结束。全局聚合的结果会在下一轮超步分发给所有Worker,供下一轮迭代使用。 如下图所示 :
Aggregator的API
Aggregator共提供了五个API供用户实现。下面逐个介绍5个API的调用时机及常规用途。
- createStartupValue(context)
该API在所有Worker上执行一次,调用时机是所有超步开始之前,通常用以初始化AggregatorValue。在第0轮超步中,调用WorkerContext.getLastAggregatedValue() 或ComputeContext.getLastAggregatedValue()可以获取该API初始化的AggregatorValue对象。
- createInitialValue(context)
该API在所有Worker上每轮超步开始时调用一次,用以初始化本轮迭代所用的AggregatorValue。通常操作是通过WorkerContext.getLastAggregatedValue() 得到上一轮迭代的结果,然后执行部分初始化操作。
- aggregate(value, item)
该API同样在所有Worker上执行,与上述API不同的是,该API由用户显示调用ComputeContext#aggregate(item)来触发,而上述两个API,则由框架自动调用。该API用以执行局部聚合操作,其中第一个参数value是本Worker在该轮超步已经聚合的结果(初始值是createInitialValue返回的对象),第二个参数是用户代码调用ComputeContext#aggregate(item)传入的参数。该API中通常用item来更新value实现聚合。所有aggregate执行完后,得到的value就是该Worker的局部聚合结果,然后由框架发送给AggregatorOwner所在的Worker。
- merge(value, partial)
该API执行于AggregatorOwner所在Worker,用以合并各Worker局部聚合的结果,达到全局聚合对象。与aggregate类似,value是已经聚合的结果,而partial待聚合的对象,同样用partial更新value。
假定有3个worker,分别是w0、w1、w2,其局部聚合结果是p0、p1、p2。假定发送到AggregatorOwner所在Worker的顺序为p1、p0、p2。那么merge执行次序为,首先执行merge(p1, p0),这样p1和p0就聚合为p1',然后执行merge(p1', p2),p1'和p2聚合为p1'',而p1''即为本轮超步全局聚合的结果。
从上述示例可以看出,当只有一个worker时,不需要执行merge方法,也就是说merge()不会被调用。
- terminate(context, value)
当AggregatorOwner所在Worker执行完merge()后,框架会调用terminate(context, value)执行最后的处理。其中第二个参数value,即为merge()最后得到全局聚合,在该方法中可以对全局聚合继续修改。执行完terminate()后,框架会将全局聚合对象分发给所有Worker,供下一轮超步使用。
terminate()方法的一个特殊之处在于,如果返回true,则整个作业就结束迭代,否则继续执行。在机器学习场景中,通常判断收敛后返回true以结束作业。
Kmeans Clustering示例
下面以典型的KmeansClustering作为示例,来看下Aggregator具体用法。附件有完整代码,这里我们逐个部分解析代码。
- GraphLoader部分
GraphLoader部分用以加载输入表,并转换为图的点或边。这里我们输入表的每行数据为一个样本,一个样本构造一个点,并用Vertex的value来存放样本。
我们首先定义一个Writable类KmeansValue作为Vertex的value类型:`
java - static class KmeansValue implements Writable {
DenseVector sample;
public KmeansValue() {
}
public KmeansValue(DenseVector v) {
this.sample = v;
}
@Override
public void write(DataOutput out) throws IOException {
wirteForDenseVector(out, sample);
}
@Override
public void readFields(DataInput in) throws IOException {
sample = readFieldsForDenseVector(in);
}
}
KmeansValue中封装一个DenseVector对象来存放一个样本,这里DenseVector类型来自[matrix-toolkits-java](https://github.com/fommil/matrix-toolkits-java/),而wirteForDenseVector()及readFieldsForDenseVector()用以实现序列化及反序列化,可参见附件中的完整代码。<br />我们自定义的KmeansReader代码如下:<br />```java
public static class KmeansReader extends
GraphLoader<LongWritable, KmeansValue, NullWritable, NullWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, KmeansValue, NullWritable, NullWritable> context)
throws IOException {
KmeansVertex v = new KmeansVertex();
v.setId(recordNum);
int n = record.size();
DenseVector dv = new DenseVector(n);
for (int i = 0; i < n; i++) {
dv.set(i, ((DoubleWritable)record.get(i)).get());
}
v.setValue(new KmeansValue(dv));
context.addVertexRequest(v);
}
}
KmeansReader中,每读入一行数据(一个Record)创建一个点,这里用recordNum作为点的ID,将record内容转换成DenseVector对象并封装进VertexValue中。
- Vertex部分
自定义的KmeansVertex代码如下。逻辑非常简单,每轮迭代要做的事情就是将自己维护的样本执行局部聚合。具体逻辑参见下面Aggregator的实现。`
java - static class KmeansVertex extends
Vertex {
@Override
public void compute(
ComputeContext<LongWritable, KmeansValue, NullWritable, NullWritable> context,
Iterable<NullWritable> messages) throws IOException {
context.aggregate(getValue());
}
}
1. Aggregator部分<br />整个Kmeans的主要逻辑集中在Aggregator中。首先是自定义的KmeansAggrValue,用以维护要聚合及分发的内容。<br />```java
public static class KmeansAggrValue implements Writable {
DenseMatrix centroids;
DenseMatrix sums; // used to recalculate new centroids
DenseVector counts; // used to recalculate new centroids
@Override
public void write(DataOutput out) throws IOException {
wirteForDenseDenseMatrix(out, centroids);
wirteForDenseDenseMatrix(out, sums);
wirteForDenseVector(out, counts);
}
@Override
public void readFields(DataInput in) throws IOException {
centroids = readFieldsForDenseMatrix(in);
sums = readFieldsForDenseMatrix(in);
counts = readFieldsForDenseVector(in);
}
}
KmeansAggrValue中维护了三个对象,其中centroids是当前的K个中心点,假定样本是m维的话,centroids就是一个K*m的矩阵。sums是和centroids大小一样的矩阵,每个元素记录了到特定中心点最近的样本特定维之和,例如sums(i,j)是到第i个中心点最近的样本的第j维度之和。
counts是个K维的向量,记录到每个中心点距离最短的样本个数。sums和counts一起用以计算新的中心点,也是要聚合的主要内容。 接下来是自定义的Aggregator实现类KmeansAggregator,我们按照上述API的顺序逐个看其实现。
首先是createStartupValue()。`
java
public static class KmeansAggregator extends Aggregator {
public KmeansAggrValue createStartupValue(WorkerContext context) throws IOException {
KmeansAggrValue av = new KmeansAggrValue();
byte[] centers = context.readCacheFile("centers");
String lines[] = new String(centers).split("n");
int rows = lines.length;
int cols = lines[0].split(",").length; // assumption rows >= 1
av.centroids = new DenseMatrix(rows, cols);
av.sums = new DenseMatrix(rows, cols);
av.sums.zero();
av.counts = new DenseVector(rows);
av.counts.zero();
for (int i = 0; i < lines.length; i++) {
String[] ss = lines[i].split(",");
for (int j = 0; j < ss.length; j++) {
av.centroids.set(i, j, Double.valueOf(ss[j]));
}
}
return av;
}
我们在该方法中初始化一个KmeansAggrValue对象,然后从资源文件centers中读取初始中心点,并赋值给centroids。而sums和counts初始化为0。<br />接来下是createInitialValue()的实现:<br />```java
@Override
public void aggregate(KmeansAggrValue value, Object item)
throws IOException {
DenseVector sample = ((KmeansValue)item).sample;
// find the nearest centroid
int min = findNearestCentroid(value.centroids, sample);
// update sum and count
for (int i = 0; i < sample.size(); i ++) {
value.sums.add(min, i, sample.get(i));
}
value.counts.add(min, 1.0d);
}
该方法中调用findNearestCentroid()(实现见附件)找到样本item欧拉距离最近的中心点索引,然后将其各个维度加到sums上,最后counts计数加1。
以上三个方法执行于所有worker上,实现局部聚合。接下来看下在AggregatorOwner所在Worker执行的全局聚合相关操作。
首先是merge的实现:`
java
@Override
public void merge(KmeansAggrValue value, KmeansAggrValue partial)
throws IOException {
value.sums.add(partial.sums);
value.counts.add(partial.counts);
}
merge的实现逻辑很简单,就是把各个worker聚合出的sums和counts相加即可。<br />最后是terminate()的实现:<br />```java
@Override
public boolean terminate(WorkerContext context, KmeansAggrValue value)
throws IOException {
// Calculate the new means to be the centroids (original sums)
DenseMatrix newCentriods = calculateNewCentroids(value.sums, value.counts, value.centroids);
// print old centroids and new centroids for debugging
System.out.println("\nsuperstep: " + context.getSuperstep() +
"\nold centriod:\n" + value.centroids + " new centriod:\n" + newCentriods);
boolean converged = isConverged(newCentriods, value.centroids, 0.05d);
System.out.println("superstep: " + context.getSuperstep() + "/"
+ (context.getMaxIteration() - 1) + " converged: " + converged);
if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
// converged or reach max iteration, output centriods
for (int i = 0; i < newCentriods.numRows(); i++) {
Writable[] centriod = new Writable[newCentriods.numColumns()];
for (int j = 0; j < newCentriods.numColumns(); j++) {
centriod[j] = new DoubleWritable(newCentriods.get(i, j));
}
context.write(centriod);
}
// true means to terminate iteration
return true;
}
// update centriods
value.centroids.set(newCentriods);
// false means to continue iteration
return false;
}
teminate()中首先根据sums和counts调用calculateNewCentroids()求平均计算出新的中心点。然后调用isConverged()根据新老中心点欧拉距离判断是否已经收敛。如果收敛或迭代次数达到最大数,则将新的中心点输出并返回true,以结束迭代。否则更新中心点并返回false以继续迭代。其中calculateNewCentroids()和isConverged()的实现见附件。
- main方法
main方法用以构造GraphJob,然后设置相应配置,并提交作业。代码如下:`
java - static void main(String[] args) throws IOException {
if (args.length < 2)
printUsage();
GraphJob job = new GraphJob();
job.setGraphLoaderClass(KmeansReader.class);
job.setRuntimePartitioning(false);
job.setVertexClass(KmeansVertex.class);
job.setAggregatorClass(KmeansAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
// default max iteration is 30
job.setMaxIteration(30);
if (args.length >= 3)
job.setMaxIteration(Integer.parseInt(args[2]));
long start = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
这里需要注意的是job.setRuntimePartitioning(false),设置为false后,各个worker加载的数据不再根据Partitioner重新分区,即谁加载的数据谁维护。
<a name="a7d80080"></a>
# 功能介绍
<a name="1922c3a4"></a>
## 运行作业
MaxCompute 客户端提供一个Jar命令用于运行 MaxCompute GRAPH作业,其使用方式与 [MapReduce](http://help.aliyun-inc.com/internaldoc/detail/27875.html)中的[Jar命令](http://help.aliyun-inc.com/internaldoc/detail/27878.html) 相同,这里仅作简要介绍:
Usage: jar [] [ARGS]
-conf <configuration_file> Specify an application configuration file
-classpath <local_file_list> classpaths used to run mainClass
-D <name>=<value> Property value pair, which will be used to run mainClass
-local Run job in local mode
-resources <resource_name_list> file/table resources used in graph, seperate by comma
其中 < GENERIC_OPTIONS>包括(均为可选参数):
* -conf <configuration file > :指定JobConf配置文件;
* -classpath <local_file_list > : 本地执行时的classpath,主要用于指定main函数所在的jar包。大多数情况下,用户更习惯于将main函数与Graph作业编写在一个包中,例如:单源最短距离算法 ,因此,在执行示例程序时,-resources及-classpath的参数中都出现了用户的jar包,但二者意义不同,-resources引用的是Graph作业,运行于分布式环境中,而-classpath引用的是main函数,运行于本地,指定的jar包路径也是本地文件路径。包名之间使用系统默认的文件分割符作分割(通常情况下,windows系统是分号”;”,linux系统是冒号”:”);
* -D <prop_name > = < prop_value > : 本地执行时,<mainClass > 的java属性,可以定义多个;
* -local:以本地模式执行Graph作业,主要用于程序调试;
* -resources <resource_name_list > : Graph作业运行时使用的资源声明。一般情况下,resource_name_list中需要指定Graph作业所在的资源名称。如果用户在Graph作业中读取了其他ODPS资源,那么,这些资源名称也需要被添加到resource_name_list中。资源之间使用逗号分隔,使用跨项目空间使用资源时,需要前面加上:PROJECT_NAME/resources/,示例:-resources otherproject/resources/resfile;
同时,用户也可以直接运行GRAPH作业的main函数直接将作业提交到 MaxCompute ,而不是通过 MaxCompute 客户端提交作业。以[PageRank算法](http://help.aliyun-inc.com/internaldoc/detail/27908.html) 为例:
public static void main(String[] args) throws IOException {
if (args.length < 2)
printUsage();
GraphJob job = new GraphJob();
job.setGraphLoaderClass(PageRankVertexReader.class);
job.setVertexClass(PageRankVertex.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
// 将作业中使用的资源添加到cache resource,对应于jar命令中 -resources 和 -libjars 中指定的资源
job.addCacheResource("mapreduce-examples.jar");
// 将使用的jar及其他文件添加到class cache resource,对应于jar命令中 -libjars 中指定的资源
job.addCacheResourceToClassPath("mapreduce-examples.jar");
// 设置console中,odps_config.ini对应的配置项,使用时替换为自己的配置
OdpsConf.getInstance().setProjName("project_name");
OdpsConf.getInstance().setEndpoint("end_point");
OdpsConf.getInstance().setAccessId("access_id");
OdpsConf.getInstance().setAccessKey("access_key");
// default max iteration is 30
job.setMaxIteration(30);
if (args.length >= 3)
job.setMaxIteration(Integer.parseInt(args[2]));
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
<a name="6354d6d6"></a>
## []()输入输出
MaxCompute GRAPH作业的输入输出限制为表,不允许用户自定义输入输出格式。<br />定义作业输入,支持多路输入:
GraphJob job = new GraphJob();
job.addInput(TableInfo.builder().tableName(“tblname”).build()); //表作为输入
job.addInput(TableInfo.builder().tableName(“tblname”).partSpec("pt1=a/pt2=b").build()); //分区作为输入
//只读取输入表的 col2 和 col0 列,在 GraphLoader 的 load 方法中,record.get(0) 得到的是col2列,顺序一致
job.addInput(TableInfo.builder().tableName(“tblname”).partSpec("pt1=a/pt2=b").build(), new String[]{"col2", "col0"});
备注:
关于作业输入定义,更多的信息参见GraphJob的addInput相关方法说明,框架读取输入表的记录传给用户自定义的GraphLoader载入图数据;
限制: 暂时不支持分区过滤条件。更多应用限制请参考 应用限制;
定义作业输出,支持多路输出,通过label标识每路输出:
GraphJob job = new GraphJob();
//输出表为分区表时需要给到最末一级分区
job.addOutput(TableInfo.builder().tableName("table_name").partSpec("pt1=a/pt2=b").build());
// 下面的参数 true 表示覆盖tableinfo指定的分区,即INSERT OVERWRITE语义,false表示INSERT INTO语义
job.addOutput(TableInfo.builder().tableName("table_name").partSpec("pt1=a/pt2=b").lable("output1").build(), true);
> 备注:
> * 关于作业输出定义,更多的信息参见GraphJob的addOutput 相关方法说明;
* Graph作业在运行时可以通过WorkerContext的write方法写出记录到输出表,多路输出需要指定标识,如上面的 “output1”;
* 更多应用限制请参考 [应用限制](http://help.aliyun-inc.com/internaldoc/detail/27905.html);
<a name="d41d8cd9"></a>
#
<a name="bbfcdb67"></a>
## 读取资源
<a name="24cb2794"></a>
### []()GRAPH程序中添加资源
除了通过jar命令指定GRAPH读取的资源外,还可以通过GraphJob的下面两个方法指定:
void addCacheResources(String resourceNames)
void addCacheResourcesToClassPath(String resourceNames)
<a name="90d49894"></a>
### []()GRAPH程序中使用资源
在 GRAPH 程序中可以通过相应的上下文对象WorkerContext的下述方法读取资源:
public byte[] readCacheFile(String resourceName) throws IOException;
public Iterable readCacheArchive(String resourceName) throws IOException;
public Iterable readCacheArchive(String resourceName, String relativePath)throws IOException;
public Iterable readResourceTable(String resourceName);
public BufferedInputStream readCacheFileAsStream(String resourceName) throws IOException;
public Iterable readCacheArchiveAsStream(String resourceName) throws IOException;
public Iterable readCacheArchiveAsStream(String resourceName, String relativePath) throws IOException;
> 备注:
> * 通常在WorkerComputer的setup方法里读取资源,然后保存在Worker Value中,之后通过getWorkerValue方法取得;
* 建议用上面的流接口,边读边处理,内存耗费少;
* 更多应用限制请参考 [应用限制](http://help.aliyun-inc.com/internaldoc/detail/27905.html);
<a name="5f839cd3"></a>
# SDK介绍
Graph SDK maven 配置:
com.aliyun.odps
odps-sdk-graph
0.20.7
sources
完整Java Doc文档,请点击 [这里](http://odps.alibaba-inc.com/doc/prddoc/odps_sdk_v2/apidocs/index.html)
| 主要接口 | 说明 |
| :--- | :--- |
| GraphJob | GraphJob继承自JobConf,用于定义、提交和管理一个 ODPS Graph 作业。 |
| Vertex | Vertex是图的点的抽象,包含属性:id,value,halted,edges,通过GraphJob的setVertexClass接口提供 Vertex 实现。 |
| Edge | Edge是图的边的抽象,包含属性:destVertexId, value,图数据结构采用邻接表,点的出边保存在点的 edges 中。 |
| GraphLoader | GraphLoader用于载入图,通过 GraphJob 的 setGraphLoaderClass 接口提供 GraphLoader 实现。 |
| VertexResolver | VertexResolver用于自定义图拓扑修改时的冲突处理逻辑,通过GraphJob的 setLoadingVertexResolverClass 和 setComputingVertexResolverClass 接口提供图加载和迭代计算过程中的图拓扑修改的冲突处理逻辑。 |
| Partitioner | Partitioner 用于对图进行划分使得计算可以分片进行,通过GraphJob的 setPartitionerClass 接口提供 Partitioner 实现,默认采用 HashPartitioner,即对点 ID 求哈希值然后对 Worker 数目取模。 |
| WorkerComputer | WorkerComputer允许在 Worker 开始和退出时执行用户自定义的逻辑,通过GraphJob的 setWorkerComputerClass 接口提供WorkerComputer 实现。 |
| Aggregator | Aggregator 的 setAggregatorClass(Class ...) 定义一个或多个 Aggregator |
| Combiner | Combiner 的 setCombinerClass 设置 Combiner |
| Counters | 计数器,在作业运行逻辑中,可以通过 WorkerContext 接口取得计数器并进行计数,框架会自动进行汇总 |
| WorkerContext | 上下文对象,封装了框架的提供的功能,如修改图拓扑结构,发送消息,写结果,读取资源等等 |
<a name="8e705b7e"></a>
# 开发和调试
ODPS没有为用户提供Graph开发插件,但用户仍然可以基于Eclipse开发ODPS Graph程序,建议的开发流程是:
* 编写Graph代码,使用本地调试进行基本的测试;
* 进行集群调试,验证结果;
<a name="9973bec7"></a>
## 开发示例
本节以[SSSP](http://help.aliyun-inc.com/internaldoc/detail/27907.html) 算法为例讲述如何用Eclipse开发和调试Graph程序。<br />下面是开发SSSP的步骤:
1. 创建Java工程,例如:graph_examples;<br />
1. 将ODPS客户端lib目录下的jar包加到Eclipse工程的Build Path里。一个配置好的Eclipse工程如下图所示。<br />
1. 开发ODPS Graph程序,实际开发过程中,常常会先拷贝一个例子(例如[SSSP](http://help.aliyun-inc.com/internaldoc/detail/27907.html)),然后再做修改。在本示例中,我们仅修改了package路径为:package com.aliyun.odps.graph.example。
1. 编译打包,在Eclipse环境中,右键点击源代码目录(图中的src目录),Export -> Java -> JAR file 生成JAR包,选择目标jar包的保存路径,例如:D:\odps\clt\odps-graph-example-sssp.jar;
1. 使用ODPS客户端运行SSSP,相关操作参考[快速开始之运行Graph](http://help.aliyun-inc.com/internaldoc/detail/27813.html)。
Eclipse 配置截图: ![](https://intranetproxy.alipay.com/skylark/lark/0/2019/png/23934/1548150832059-15fe7b48-5b7f-45b9-b9fd-5d8dd2014732.png#align=left&display=inline&height=448&originHeight=770&originWidth=1281&size=0&width=746)
> 注意:
> * 相关的开发步骤请参考[Graph开发插件介绍](http://help.aliyun-inc.com/internaldoc/detail/27985.html).
<a name="8f6be038"></a>
## 本地调试
ODPS GRAPH支持本地调试模式,可以使用Eclipse进行断点调试。<br />断点调试步骤如下:
* 下载一个odps-graph-local的maven包。
* 选择Eclipse工程,右键点击GRAPH作业主程序(包含main函数)文件,配置其运行参数(Run As -> Run Configurations…),如下图。
* 在Arguments tab页中,设置Program arguments 参数为“1 sssp_in sssp_out”,作为主程序的输入参数;
* 在Arguments tab页中,设置VM arguments参数为:<br />-Dodps.runner.mode=local -Dodps.project.name=<project.name> -Dodps.end.point=<end.point> -Dodps.access.id=<access.id> -Dodps.access.key=<access.key>
![](https://intranetproxy.alipay.com/skylark/lark/0/2019/png/23934/1548150832059-b6f956c0-ecdc-4de9-9338-3ff9e9305261.png#align=left&display=inline&height=597&originHeight=640&originWidth=800&size=0&width=746)
* 对于本地模式(即odps.end.point参数不指定),需要在warehouse创建sssp_in,sssp_out表,为输入表 sssp_in 添加数据,输入数据如下。关于warehouse的介绍请参考[MapReduce本地运行](http://help.aliyun-inc.com/internaldoc/detail/27882.html) 部分;
1,"2:2,3:1,4:4"
2,"1:2,3:2,4:1"
3,"1:1,2:2,5:1"
4,"1:4,2:1,5:1"
5,"3:1,4:1"
* 点击Run按钮即可本地跑SSSP;
其中:参数设置可参考ODPS客户端中conf/odps_config.ini的设置,上述是几个常用参数,其他参数也说明如下:
* odps.runner.mode:取值为local,本地调试功能必须指定;
* odps.project.name:指定当前project,必须指定;
* odps.end.point:指定当前odps服务的地址,可以不指定,如果不指定,只从warehouse读取表或资源的meta和数据,不存在则抛异常,如果指定,会先从warehouse读取,不存在时会远程连接odps读取;
* odps.access.id:连接odps服务的id,只在指定odps.end.point时有效;
* odps.access.key:连接odps服务的key,只在指定odps.end.point时有效;
* odps.cache.resources:指定使用的资源列表,效果与jar命令的“-resources”相同;
* odps.local.warehouse: 本地warehouse路径,不指定时默认为./warehouse;
在 Eclipse 中本地跑 SSSP的调试输出信息如下:
Counters: 3
com.aliyun.odps.graph.local.COUNTER
TASK_INPUT_BYTE=211
TASK_INPUT_RECORD=5
TASK_OUTPUT_BYTE=161
TASK_OUTPUT_RECORD=5
graph task finish
> 注意:在上面的示例中,需要本地warehouse下有sssp_in及sssp_out表。sssp_in及sssp_out的详细信息请参考[快速开始之运行Graph](http://help.aliyun-inc.com/internaldoc/detail/27813.html)中的介绍。
<a name="035a3e86"></a>
## 本地作业临时目录
每运行一次本地调试,都会在 Eclipse 工程目录下新建一个临时目录,见下图:<br />![](https://intranetproxy.alipay.com/skylark/lark/0/2019/png/23934/1548150832063-a27eee32-76f2-496b-9e7c-4aa622ebf33d.png#align=left&display=inline&height=199&originHeight=199&originWidth=272&size=0&width=272)<br />一个本地运行的GRAPH作业临时目录包括以下几个目录和文件:
* counters - 存放作业运行的一些计数信息;
* inputs - 存放作业的输入数据,优先取自本地的 warehouse,如果本地没有,会通过 ODPS SDK 从服务端读取(如果设置了 odps.end.point),默认一个 input 只读10 条数据,可以通过 -Dodps.mapred.local.record.limit 参数进行修改,但是也不能超过1万条记录;
* outputs - 存放作业的输出数据,如果本地warehouse中存在输出表,outputs里的结果数据在作业执行完后会覆盖本地warehouse中对应的表;
* resources - 存放作业使用的资源,与输入类似,优先取自本地的warehouse,如果本地没有,会通过ODPS SDK从服务端读取(如果设置了 odps.end.point);
* job.xml - 作业配置
* superstep - 存放每一轮迭代的消息持久化信息。> 注意:
> * 如果需要本地调试时输出详细日志,需要在 src 目录下放一个 log4j 的配置文件:log4j.properties_odps_graph_cluster_debug。
<a name="f949cc7a"></a>
## 集群调试
在通过本地的调试之后,可以提交作业到集群进行测试,通常步骤:
1. 配置ODPS客户端;
1. 使用“add jar /path/work.jar -f;”命令更新jar包;
1. 使用jar命令运行作业,查看运行日志和结果数据,如下所示;
> 注意:
> * 集群运行Graph的详细介绍可以参考[快速开始之运行Graph](http://help.aliyun-inc.com/internaldoc/detail/27813.html)。
<a name="5865fdf5"></a>
## 性能调优
下面主要从 ODPS Graph 框架角度介绍常见性能优化的几个方面:
<a name="5c70b720"></a>
## 作业参数配置
对性能有所影响的 GraphJob 配置项包括:
* setSplitSize(long) // 输入表切分大小,单位MB,大于0,默认64;
* setNumWorkers(int) // 设置作业worker数量,范围:[1, 1000], 默认值-1, worker数由作业输入字节数和split size决定;
* setWorkerCPU(int) // Map CPU资源,100为1cpu核,[50,800]之间,默认200;
* setWorkerMemory(int) // Map 内存资源,单位MB,[256M,12G]之间,默认4096M;
* setMaxIteration(int) // 设置最大迭代次数,默认 -1,小于或等于 0 时表示最大迭代次数不作为作业终止条件;
* setJobPriority(int) // 设置作业优先级,范围:[0, 9],默认9,数值越大优先级越小。
通常情况下:
1. 可以考虑使用setNumWorkers方法增加 worker 数目;
1. 可以考虑使用setSplitSize方法减少切分大小,提高作业载入数据速度;
1. 加大 worker 的 cpu 或内存;
1. 设置最大迭代次数,有些应用如果结果精度要求不高,可以考虑减少迭代次数,尽快结束;
接口 setNumWorkers 与 setSplitSize 配合使用,可以提高数据的载入速度。假设 setNumWorkers 为 workerNum, setSplitSize 为 splitSize, 总输入字节数为 inputSize, 则输入被切分后的块数 splitNum = inputSize / splitSize,workerNum 和 splitNum 之间的关系:
1. 若 splitNum == workerNum,每个 worker 负责载入一个 split;
1. 若 splitNum > workerNum,每个 worker 负责载入一个或多个 split;
1. 若 splitNum < workerNum, 每个 worker 负责载入零个或一个 split。
因此,应调节 workerNum 和 splitSize,在满足前两种情况时,数据载入比较快。迭代阶段只调节 workerNum 即可。 如果设置 runtime partitioning 为 false,则建议直接使用 setSplitSize 控制 worker 数量,或者保证满足前两种情况,在出现第三种情况时,部分 worker 上点数会为0. 可以在 jar 命令前使用set odps.graph.split.size=<m>; set odps.graph.worker.num=<n>; 与 setNumWorkers 和 setSplitSize 等效。<br />另外一种常见的性能问题:数据倾斜,反应到 Counters 就是某些 worker 处理的点或边数量远远超过其他 worker。<br />数据倾斜的原因通常是某些 key 对应的点、边,或者消息的数量远远超出其他 key,这些 key 被分到少量的 worker 处理,从而导致这些 worker 相对于其他运行时间长很多,解决方法:
* 可以试试 Combiner,把这些 key 对应点的消息进行本地聚合,减少消息发生;
* 改进业务逻辑。
<a name="6bb8613c"></a>
## 运用Combiner
开发人员可定义 Combiner 来减少存储消息的内存和网络数据流量,缩短作业的执行时间。细节见 SDK中Combiner的介绍。
<a name="654376db"></a>
## 减少数据输入量
数据量大时,读取磁盘中的数据可能耗费一部分处理时间,因此,减少需要读取的数据字节数可以提高总体的吞吐量,从而提高作业性能。可供选择的方法有如下几种:
* 减少输入数据量:对某些决策性质的应用,处理数据采样后子集所得到的结果只可能影响结果的精度,而并不会影响整体的准确性,因此可以考虑先对数据进行特定采样后再导入输入表中进行处理
* 避免读取用不到的字段:ODPS Graph 框架的 TableInfo 类支持读取指定的列(以列名数组方式传入),而非整个表或表分区,这样也可以减少输入的数据量,提高作业性能
<a name="e3f29de3"></a>
## 内置jar包
下面这些 jar 包会默认加载到运行 GRAPH 程序的 JVM 中,用户可以不必上传这些资源,也不必在命令行的 -libjars 带上这些 jar 包:
* commons-codec-1.3.jar
* commons-io-2.0.1.jar
* commons-lang-2.5.jar
* commons-logging-1.0.4.jar
* commons-logging-api-1.0.4.jar
* guava-14.0.jar
* json.jar
* log4j-1.2.15.jar
* slf4j-api-1.4.3.jar
* slf4j-log4j12-1.4.3.jar
* xmlenc-0.52.jar
> 注意:
> * 在起 JVM 的CLASSPATH 里,上述内置 jar 包会放在用户 jar 包的前面,所以可能产生版本冲突,例如:用户的程序中使用了 commons-codec-1.5.jar 某个类的函数,但是这个函数不在 commons-codec-1.3.jar 中,这时只能看 1.3 版本里是否有满足你需求的实现,或者等待ODPS升级新版本。
<a name="babfb10e"></a>
# 应用限制
* 单个job引用的resource数量不超过256个,table、archive按照一个单位计算;
* 单个job引用的resource总计字节数大小不超过512M;
* 单个job的输入路数不能超过1024(输入表的个数不能超过64),单个job的输出路数不能超过256;
* 多路输出中指定的label不能为null或者为空字符串,长度不能超过256,只能包括A-Z,a-z,0-9,_,#,.,-等;
* 单个job中自定义counter的数量不能超过64,counter的group name和counter name中不能带有#,两者长度和不能超过100;
* 单个job的worker数由框架计算得出,最大为 1000, 超过抛异常;
* 单个worker占用cpu默认为200,范围[50, 800];
* 单个worker占用memory默认为4096,范围[256M, 12G];
* 单个worker重复读一个resource次数限制不大于64次;
* plit size默认为64M,用户可设置,范围:0 < split_size <= (9223372036854775807 >> 20);
* ODPS Graph程序中的GraphLoader/Vertex/Aggregator等在集群运行时,受到Java沙箱的限制(Graph作业的主程序则不受此限制),具体限制如 [Java沙箱](http://help.aliyun-inc.com/internaldoc/detail/34631.html) 所示。
<a name="3e49d1b2"></a>
# 示例程序
<a name="f56cb8a8"></a>
## 单源最短距离
Dijkstra 算法是求解有向图中单源最短距离(Single Source Shortest Path,简称为 SSSP)的经典算法。<br />最短距离:对一个有权重的有向图 G=(V,E),从一个源点 s 到汇点 v 有很多路径,其中边权和最小的路径,称从 s 到 v 的最短距离。<br />算法基本原理,如下所示:
* 初始化:源点 s 到 s 自身的距离(d[s]=0),其他点 u 到 s 的距离为无穷(d[u]=∞)。<br />
* 迭代:若存在一条从 u 到 v 的边,那么从 s 到 v 的最短距离更新为:d[v]=min(d[v], d[u]+weight(u, v)),直到所有的点到 s 的距离不再发生变化时,迭代结束。<br />
由算法基本原理可以看出,此算法非常适合使用 MaxCompute Graph 程序进行求解:每个点维护到源点的当前最短距离值,当这个值变化时,将新值加上边的权值发送消息通知其邻接点,下一轮迭代时,邻接点根据收到的消息更新其当前最短距离,当所有点当前最短距离不再变化时,迭代结束。
<a name="b5ea48ff"></a>
### []()代码示例
单源最短距离的代码,如下所示:
import java.io.IOException;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.data.TableInfo;
public class SSSP {
public static final String START_VERTEX = "sssp.start.vertex.id";
public static class SSSPVertex extends
Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {
private static long startVertexId = -1;
public SSSPVertex() {
this.setValue(new LongWritable(Long.MAX_VALUE));
}
public boolean isStartVertex(
ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context) {
if (startVertexId == -1) {
String s = context.getConfiguration().get(START_VERTEX);
startVertexId = Long.parseLong(s);
}
return getId().get() == startVertexId;
}
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
Iterable<LongWritable> messages) throws IOException {
long minDist = isStartVertex(context) ? 0 : Integer.MAX_VALUE;
for (LongWritable msg : messages) {
if (msg.get() < minDist) {
minDist = msg.get();
}
}
if (minDist < this.getValue().get()) {
this.setValue(new LongWritable(minDist));
if (hasEdges()) {
for (Edge<LongWritable, LongWritable> e : this.getEdges()) {
context.sendMessage(e.getDestVertexId(), new LongWritable(minDist
+ e.getValue().get()));
}
}
} else {
voteToHalt();
}
}
@Override
public void cleanup(
WorkerContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
throws IOException {
context.write(getId(), getValue());
}
}
public static class MinLongCombiner extends
Combiner<LongWritable, LongWritable> {
@Override
public void combine(LongWritable vertexId, LongWritable combinedMessage,
LongWritable messageToCombine) throws IOException {
if (combinedMessage.get() > messageToCombine.get()) {
combinedMessage.set(messageToCombine.get());
}
}
}
public static class SSSPVertexReader extends
GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
throws IOException {
SSSPVertex vertex = new SSSPVertex();
vertex.setId((LongWritable) record.get(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
String[] ss = edges[i].split(":");
vertex.addEdge(new LongWritable(Long.parseLong(ss[0])),
new LongWritable(Long.parseLong(ss[1])));
}
context.addVertexRequest(vertex);
}
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
System.out.println("Usage: <startnode> <input> <output>");
System.exit(-1);
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(SSSPVertexReader.class);
job.setVertexClass(SSSPVertex.class);
job.setCombinerClass(MinLongCombiner.class);
job.set(START_VERTEX, args[0]);
job.addInput(TableInfo.builder().tableName(args[1]).build());
job.addOutput(TableInfo.builder().tableName(args[2]).build());
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
上述代码,说明如下:
* 第 19 行:定义 SSSPVertex ,其中:
* 点值表示该点到源点 startVertexId 的当前最短距离。<br />
* compute() 方法使用迭代公式:d[v]=min(d[v], d[u]+weight(u, v)) 更新点值。<br />
* cleanup() 方法把点及其到源点的最短距离写到结果表中。<br />
* 第 58 行:当点值没发生变化时,调用 voteToHalt() 告诉框架该点进入 halt 状态,当所有点都进入 halt 状态时,计算结束。<br />
* 第 70 行:定义 MinLongCombiner,对发送给同一个点的消息进行合并,优化性能,减少内存占用。<br />
* 第 83 行:定义 SSSPVertexReader 类,加载图,将表中每一条记录解析为一个点,记录的第一列是点标识,第二列存储该点起始的所有的边集,内容如:2:2,3:1,4:4。<br />
* 第 106 行:主程序(main 函数),定义 GraphJob,指定 Vertex/GraphLoader/Combiner 等的实现,指定输入输出表。
<a name="PageRank"></a>
## PageRank
PageRank 算法是计算网页排名的经典算法:输入是一个有向图 G,其中顶点表示网页,如果存在网页 A 到网页 B 的链接,那么存在连接 A 到 B 的边。<br />算法基本原理,如下所示:
* 初始化:点值表示 PageRank 的 rank 值(double 类型),初始时,所有点取值为 1/TotalNumVertices。<br />
* 迭代公式:PageRank(i)=0.15/TotalNumVertices+0.85*sum,其中 sum 为所有指向 i 点的点(设为 j) PageRank(j)/out_degree(j) 的累加值。<br />
由算法基本原理可以看出,此算法非常适合使用 MaxCompute Graph 程序进行求解:每个点 j 维护其 PageRank 值,每一轮迭代都将 PageRank(j)/out_degree(j) 发给其邻接点(向其投票),下一轮迭代时,每个点根据迭代公式重新计算 PageRank 取值。
<a name="b5ea48ff"></a>
### []()代码示例
import java.io.IOException;
import org.apache.log4j.Logger;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Writable;
public class PageRank {
private final static Logger LOG = Logger.getLogger(PageRank.class);
public static class PageRankVertex extends
Vertex<Text, DoubleWritable, NullWritable, DoubleWritable> {
@Override
public void compute(
ComputeContext<Text, DoubleWritable, NullWritable, DoubleWritable> context,
Iterable<DoubleWritable> messages) throws IOException {
if (context.getSuperstep() == 0) {
setValue(new DoubleWritable(1.0 / context.getTotalNumVertices()));
} else if (context.getSuperstep() >= 1) {
double sum = 0;
for (DoubleWritable msg : messages) {
sum += msg.get();
}
DoubleWritable vertexValue = new DoubleWritable(
(0.15f / context.getTotalNumVertices()) + 0.85f * sum);
setValue(vertexValue);
}
if (hasEdges()) {
context.sendMessageToNeighbors(this, new DoubleWritable(getValue()
.get() / getEdges().size()));
}
}
@Override
public void cleanup(
WorkerContext<Text, DoubleWritable, NullWritable, DoubleWritable> context)
throws IOException {
context.write(getId(), getValue());
}
}
public static class PageRankVertexReader extends
GraphLoader<Text, DoubleWritable, NullWritable, DoubleWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<Text, DoubleWritable, NullWritable, DoubleWritable> context)
throws IOException {
PageRankVertex vertex = new PageRankVertex();
vertex.setValue(new DoubleWritable(0));
vertex.setId((Text) record.get(0));
System.out.println(record.get(0));
for (int i = 1; i < record.size(); i++) {
Writable edge = record.get(i);
System.out.println(edge.toString());
if (!(edge.equals(NullWritable.get()))) {
vertex.addEdge(new Text(edge.toString()), NullWritable.get());
}
}
LOG.info("vertex edgs size: "
+ (vertex.hasEdges() ? vertex.getEdges().size() : 0));
context.addVertexRequest(vertex);
}
}
private static void printUsage() {
System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
System.exit(-1);
}
public static void main(String[] args) throws IOException {
if (args.length < 2)
printUsage();
GraphJob job = new GraphJob();
job.setGraphLoaderClass(PageRankVertexReader.class);
job.setVertexClass(PageRankVertex.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
// default max iteration is 30
job.setMaxIteration(30);
if (args.length >= 3)
job.setMaxIteration(Integer.parseInt(args[2]));
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
上述代码,说明如下:
* 第 23 行:定义 PageRankVertex ,其中:
* 点值表示该点(网页)的当前 PageRank 取值。<br />
* compute() 方法使用迭代公式:`PageRank(i)=0.15/TotalNumVertices+0.85*sum`更新点值。<br />
* cleanup() 方法把点及其 PageRank 取值写到结果表中。<br />
* 第 55 行:定义 PageRankVertexReader 类,加载图,将表中每一条记录解析为一个点,记录的第一列是起点,其他列为终点。<br />
* 第 88 行:主程序(main 函数),定义 GraphJob,指定 Vertex/GraphLoader 等的实现,以及最大迭代次数(默认 30),并指定输入输出表。
<a name="5732cb27"></a>
## K-均值聚类
k-均值聚类(Kmeans) 算法是非常基础并大量使用的聚类算法。<br />算法基本原理:以空间中 k 个点为中心进行聚类,对最靠近它们的点进行归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。<br />假设要把样本集分为 k 个类别,算法描述如下:
1. 适当选择 k 个类的初始中心。<br />
1. 在第 i 次迭代中,对任意一个样本,求其到 k 个中心的距离,将该样本归到距离最短的中心所在的类。<br />
1. 利用均值等方法更新该类的中心值。<br />
1. 对于所有的 k 个聚类中心,如果利用上两步的迭代法更新后,值保持不变或者小于某个阈值,则迭代结束,否则继续迭代。<br />
<a name="b5ea48ff"></a>
### []()代码示例
K-均值聚类算法的代码,如下所示:
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.log4j.Logger;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
public class Kmeans {
private final static Logger LOG = Logger.getLogger(Kmeans.class);
public static class KmeansVertex extends
Vertex<Text, Tuple, NullWritable, NullWritable> {
@Override
public void compute(
ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
Iterable<NullWritable> messages) throws IOException {
context.aggregate(getValue());
}
}
public static class KmeansVertexReader extends
GraphLoader<Text, Tuple, NullWritable, NullWritable> {
@Override
public void load(LongWritable recordNum, WritableRecord record,
MutationContext<Text, Tuple, NullWritable, NullWritable> context)
throws IOException {
KmeansVertex vertex = new KmeansVertex();
vertex.setId(new Text(String.valueOf(recordNum.get())));
vertex.setValue(new Tuple(record.getAll()));
context.addVertexRequest(vertex);
}
}
public static class KmeansAggrValue implements Writable {
Tuple centers = new Tuple();
Tuple sums = new Tuple();
Tuple counts = new Tuple();
@Override
public void write(DataOutput out) throws IOException {
centers.write(out);
sums.write(out);
counts.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
centers = new Tuple();
centers.readFields(in);
sums = new Tuple();
sums.readFields(in);
counts = new Tuple();
counts.readFields(in);
}
@Override
public String toString() {
return "centers " + centers.toString() + ", sums " + sums.toString()
+ ", counts " + counts.toString();
}
}
public static class KmeansAggregator extends Aggregator {
@SuppressWarnings("rawtypes")
@Override
public KmeansAggrValue createInitialValue(WorkerContext context)
throws IOException {
KmeansAggrValue aggrVal = null;
if (context.getSuperstep() == 0) {
aggrVal = new KmeansAggrValue();
aggrVal.centers = new Tuple();
aggrVal.sums = new Tuple();
aggrVal.counts = new Tuple();
byte[] centers = context.readCacheFile("centers");
String lines[] = new String(centers).split("\n");
for (int i = 0; i < lines.length; i++) {
String[] ss = lines[i].split(",");
Tuple center = new Tuple();
Tuple sum = new Tuple();
for (int j = 0; j < ss.length; ++j) {
center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
sum.append(new DoubleWritable(0.0));
}
LongWritable count = new LongWritable(0);
aggrVal.sums.append(sum);
aggrVal.counts.append(count);
aggrVal.centers.append(center);
}
} else {
aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
}
return aggrVal;
}
@Override
public void aggregate(KmeansAggrValue value, Object item) {
int min = 0;
double mindist = Double.MAX_VALUE;
Tuple point = (Tuple) item;
for (int i = 0; i < value.centers.size(); i++) {
Tuple center = (Tuple) value.centers.get(i);
// use Euclidean Distance, no need to calculate sqrt
double dist = 0.0d;
for (int j = 0; j < center.size(); j++) {
double v = ((DoubleWritable) point.get(j)).get()
- ((DoubleWritable) center.get(j)).get();
dist += v * v;
}
if (dist < mindist) {
mindist = dist;
min = i;
}
}
// update sum and count
Tuple sum = (Tuple) value.sums.get(min);
for (int i = 0; i < point.size(); i++) {
DoubleWritable s = (DoubleWritable) sum.get(i);
s.set(s.get() + ((DoubleWritable) point.get(i)).get());
}
LongWritable count = (LongWritable) value.counts.get(min);
count.set(count.get() + 1);
}
@Override
public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple that = (Tuple) partial.sums.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
s.set(s.get() + ((DoubleWritable) that.get(j)).get());
}
}
for (int i = 0; i < value.counts.size(); i++) {
LongWritable count = (LongWritable) value.counts.get(i);
count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
}
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, KmeansAggrValue value)
throws IOException {
// compute new centers
Tuple newCenters = new Tuple(value.sums.size());
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple newCenter = new Tuple(sum.size());
LongWritable c = (LongWritable) value.counts.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
double val = s.get() / c.get();
newCenter.set(j, new DoubleWritable(val));
// reset sum for next iteration
s.set(0.0d);
}
// reset count for next iteration
c.set(0);
newCenters.set(i, newCenter);
}
// update centers
Tuple oldCenters = value.centers;
value.centers = newCenters;
LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
// compare new/old centers
boolean converged = true;
for (int i = 0; i < value.centers.size() && converged; i++) {
Tuple oldCenter = (Tuple) oldCenters.get(i);
Tuple newCenter = (Tuple) newCenters.get(i);
double sum = 0.0d;
for (int j = 0; j < newCenter.size(); j++) {
double v = ((DoubleWritable) newCenter.get(j)).get()
- ((DoubleWritable) oldCenter.get(j)).get();
sum += v * v;
}
double dist = Math.sqrt(sum);
LOG.info("old center: " + oldCenter + ", new center: " + newCenter
+ ", dist: " + dist);
// converge threshold for each center: 0.05
converged = dist < 0.05d;
}
if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
// converged or reach max iteration, output centers
for (int i = 0; i < value.centers.size(); i++) {
context.write(((Tuple) value.centers.get(i)).toArray());
}
// true means to terminate iteration
return true;
}
// false means to continue iteration
return false;
}
}
private static void printUsage() {
System.out.println("Usage: <in> <out> [Max iterations (default 30)]");
System.exit(-1);
}
public static void main(String[] args) throws IOException {
if (args.length < 2)
printUsage();
GraphJob job = new GraphJob();
job.setGraphLoaderClass(KmeansVertexReader.class);
job.setRuntimePartitioning(false);
job.setVertexClass(KmeansVertex.class);
job.setAggregatorClass(KmeansAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
// default max iteration is 30
job.setMaxIteration(30);
if (args.length >= 3)
job.setMaxIteration(Integer.parseInt(args[2]));
long start = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
<br />上述代码,说明如下:
* 第 26 行:定义 KmeansVertex,compute() 方法非常简单,只是调用上下文对象的 aggregate 方法,传入当前点的取值(Tuple 类型,向量表示)。<br />
* 第 38 行:定义 KmeansVertexReader 类,加载图,将表中每一条记录解析为一个点,点标识无关紧要,这里取传入的 recordNum 序号作为标识,点值为记录的所有列组成的 Tuple。<br />
* 第 83 行:定义 KmeansAggregator,这个类封装了 Kmeans 算法的主要逻辑,其中:
* createInitialValue 为每一轮迭代创建初始值(k 类中心点),若是第一轮迭代(superstep=0),该取值为初始中心点,否则取值为上一轮结束时的新中心点。<br />
* aggregate 方法为每个点计算其到各个类中心的距离,并归为距离最短的类,并更新该类的 sum 和 count。<br />
* merge 方法合并来自各个 worker 收集的 sum 和 count。<br />
* terminate 方法根据各个类的 sum 和 count 计算新的中心点,若新中心点与之前的中心点距离小于某个阈值或者迭代次数到达最大迭代次数设置,则终止迭代(返回 false),写最终的中心点到结果表。<br />
* 第 236 行:主程序(main 函数),定义 GraphJob,指定 Vertex/GraphLoader/Aggregator 等的实现,以及最大迭代次数(默认 30),并指定输入输出表。<br />
* 第 243 行:job.setRuntimePartitioning(false),对于 Kmeans 算法,加载图是不需要进行点的分发,设置 RuntimePartitioning 为 false,以提升加载图时的性能。
<a name="BiPartiteMatchiing"></a>
## BiPartiteMatchiing
二分图是指图的所有顶点可分为两个集合,每条边对应的两个顶点分别属于这两个集合。对于一个二分图 G,M 是它的一个子图,如果 M 的边集中任意两条边都不依附于同一个顶点,则称 M 为一个匹配。二分图匹配常用于有明确供需关系场景(如交友网站等)下的信息匹配行为。<br />算法描述,如下所示:
* 从左边第 1 个顶点开始,挑选未匹配点进行搜索,寻找增广路。<br />
* 如果经过一个未匹配点,说明寻找成功。<br />
* 更新路径信息,匹配边数 +1,停止搜索。<br />
* 如果一直没有找到增广路,则不再从这个点开始搜索。<br />
<a name="b5ea48ff"></a>
### []()代码示例
BiPartiteMatchiing 算法的代码,如下所示:
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Random;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;
public class BipartiteMatching {
private static final Text UNMATCHED = new Text("UNMATCHED");
public static class TextPair implements Writable {
public Text first;
public Text second;
public TextPair() {
first = new Text();
second = new Text();
}
public TextPair(Text first, Text second) {
this.first = new Text(first);
this.second = new Text(second);
}
@Override
public void write(DataOutput out) throws IOException {
first.write(out);
second.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
first = new Text();
first.readFields(in);
second = new Text();
second.readFields(in);
}
@Override
public String toString() {
return first + ": " + second;
}
}
public static class BipartiteMatchingVertexReader extends
GraphLoader<Text, TextPair, NullWritable, Text> {
@Override
public void load(LongWritable recordNum, WritableRecord record,
MutationContext<Text, TextPair, NullWritable, Text> context)
throws IOException {
BipartiteMatchingVertex vertex = new BipartiteMatchingVertex();
vertex.setId((Text) record.get(0));
vertex.setValue(new TextPair(UNMATCHED, (Text) record.get(1)));
String[] adjs = record.get(2).toString().split(",");
for (String adj : adjs) {
vertex.addEdge(new Text(adj), null);
}
context.addVertexRequest(vertex);
}
}
public static class BipartiteMatchingVertex extends
Vertex<Text, TextPair, NullWritable, Text> {
private static final Text LEFT = new Text("LEFT");
private static final Text RIGHT = new Text("RIGHT");
private static Random rand = new Random();
@Override
public void compute(
ComputeContext<Text, TextPair, NullWritable, Text> context,
Iterable<Text> messages) throws IOException {
if (isMatched()) {
voteToHalt();
return;
}
switch ((int) context.getSuperstep() % 4) {
case 0:
if (isLeft()) {
context.sendMessageToNeighbors(this, getId());
}
break;
case 1:
if (isRight()) {
Text luckyLeft = null;
for (Text message : messages) {
if (luckyLeft == null) {
luckyLeft = new Text(message);
} else {
if (rand.nextInt(1) == 0) {
luckyLeft.set(message);
}
}
}
if (luckyLeft != null) {
context.sendMessage(luckyLeft, getId());
}
}
break;
case 2:
if (isLeft()) {
Text luckyRight = null;
for (Text msg : messages) {
if (luckyRight == null) {
luckyRight = new Text(msg);
} else {
if (rand.nextInt(1) == 0) {
luckyRight.set(msg);
}
}
}
if (luckyRight != null) {
setMatchVertex(luckyRight);
context.sendMessage(luckyRight, getId());
}
}
break;
case 3:
if (isRight()) {
for (Text msg : messages) {
setMatchVertex(msg);
}
}
break;
}
}
@Override
public void cleanup(
WorkerContext<Text, TextPair, NullWritable, Text> context)
throws IOException {
context.write(getId(), getValue().first);
}
private boolean isMatched() {
return !getValue().first.equals(UNMATCHED);
}
private boolean isLeft() {
return getValue().second.equals(LEFT);
}
private boolean isRight() {
return getValue().second.equals(RIGHT);
}
private void setMatchVertex(Text matchVertex) {
getValue().first.set(matchVertex);
}
}
private static void printUsage() {
System.err.println("BipartiteMatching <input> <output> [maxIteration]");
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
printUsage();
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(BipartiteMatchingVertexReader.class);
job.setVertexClass(BipartiteMatchingVertex.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
int maxIteration = 30;
if (args.length > 2) {
maxIteration = Integer.parseInt(args[2]);
}
job.setMaxIteration(maxIteration);
job.run();
}
}
<a name="d41d8cd9"></a>
##