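Example programs
Strongly connected components
In a directed graph, if every vertex can reach every other vertex by following the edges of the graph, the graph is said to be strongly connected. A maximal strongly connected subgraph of a directed graph is called a strongly connected component (SCC). This example is based on the parallel Coloring algorithm.
Each vertex holds two pieces of data:
colorID: stores the color of vertex v during forward traversal. When the computation finishes, vertices with the same colorID belong to the same strongly connected component. (In the code, this is the sccID field.)
transposeNeighbors: stores the IDs of the neighbors of vertex v in the transpose of the input graph. (In the code, this is the inNeighbors field.)
The algorithm consists of the following four phases:
Transpose graph formation: takes two supersteps. In the first superstep, each vertex sends its ID to the neighbors on its outgoing edges. In the second superstep, each vertex stores the IDs it received as its transposeNeighbors.
Trimming: takes one superstep. Every vertex that has only incoming edges or only outgoing edges (or neither) sets its colorID to its own ID and becomes inactive. Messages sent to it afterwards are ignored.
Forward traversal: consists of two sub-phases (supersteps), Start and Rest. In Start, each vertex sets its colorID to its own ID and sends it to the neighbors on its outgoing edges. In Rest, each vertex that receives a colorID smaller than its own adopts the minimum received value and propagates it, and this repeats until the colorIDs converge. Once they converge, the master switches the global phase to backward traversal (in the code, SCCAggregator.terminate() performs this switch).
Backward traversal: also consists of Start and Rest. In Start, every vertex whose ID equals its colorID sends its ID to its neighbors in the transpose graph and becomes inactive; messages sent to it afterwards are ignored. In each Rest superstep, every vertex that receives a message matching its colorID propagates its colorID in the transpose graph and then becomes inactive. When this phase ends, if any vertex is still active, the algorithm returns to the trimming phase.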
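To make the four phases concrete, here is a single-machine sketch in plain Java rather than the ODPS Graph API. The class SccColoringSketch and its methods are hypothetical; each inner loop stands in for a run of supersteps, and a vertex's color ends up as the minimum ID among the active vertices that can reach it, matching the minimum rule used by the code below. One deliberate difference: the trimming step here only counts neighbors that are still active, while the ODPS code checks the full neighbor lists.

```java
import java.util.*;

/**
 * Single-machine sketch of the coloring SCC algorithm (hypothetical class, not the
 * ODPS Graph API). Each inner loop stands in for a run of supersteps.
 */
public class SccColoringSketch {

    /** Maps each vertex ID to the smallest vertex ID in its strongly connected component. */
    public static Map<Long, Long> scc(Map<Long, List<Long>> edges) {
        // All vertices, including those that appear only as edge targets.
        Set<Long> active = new TreeSet<>(edges.keySet());
        for (List<Long> targets : edges.values()) active.addAll(targets);

        // Phase 1: transpose graph formation, keeping in-neighbors next to out-neighbors.
        Map<Long, List<Long>> out = new HashMap<>();
        Map<Long, List<Long>> in = new HashMap<>();
        for (Long v : active) {
            out.put(v, new ArrayList<>());
            in.put(v, new ArrayList<>());
        }
        for (Map.Entry<Long, List<Long>> e : edges.entrySet()) {
            for (Long w : e.getValue()) {
                out.get(e.getKey()).add(w);
                in.get(w).add(e.getKey());
            }
        }

        Map<Long, Long> result = new HashMap<>();
        while (!active.isEmpty()) {
            // Phase 2: trimming. Without an active in-neighbor or an active out-neighbor,
            // a vertex cannot lie on a cycle, so it forms an SCC by itself.
            for (Long v : new ArrayList<>(active)) {
                boolean hasIn = in.get(v).stream().anyMatch(active::contains);
                boolean hasOut = out.get(v).stream().anyMatch(active::contains);
                if (!hasIn || !hasOut) {
                    result.put(v, v);
                    active.remove(v);
                }
            }

            // Phase 3: forward traversal. Push the minimum color along out-edges
            // until nothing changes (the "updated" flag in the ODPS aggregator).
            Map<Long, Long> color = new HashMap<>();
            for (Long v : active) color.put(v, v);
            boolean changed = true;
            while (changed) {
                changed = false;
                for (Long v : active) {
                    for (Long w : out.get(v)) {
                        if (active.contains(w) && color.get(v) < color.get(w)) {
                            color.put(w, color.get(v));
                            changed = true;
                        }
                    }
                }
            }

            // Phase 4: backward traversal. Each root (ID == color) walks the transpose
            // graph through vertices of its own color; together they form one SCC.
            for (Long root : new ArrayList<>(active)) {
                if (!root.equals(color.get(root))) continue;
                Deque<Long> queue = new ArrayDeque<>();
                Set<Long> component = new HashSet<>();
                queue.add(root);
                component.add(root);
                while (!queue.isEmpty()) {
                    for (Long u : in.get(queue.poll())) {
                        if (root.equals(color.get(u)) && component.add(u)) queue.add(u);
                    }
                }
                for (Long v : component) result.put(v, root);
                active.removeAll(component);
            }
            // Vertices that are still active go back to trimming in the next round.
        }
        return result;
    }

    public static void main(String[] args) {
        Map<Long, List<Long>> g = new HashMap<>();
        g.put(1L, Arrays.asList(2L));
        g.put(2L, Arrays.asList(3L));
        g.put(3L, Arrays.asList(1L, 4L));
        g.put(4L, Arrays.asList(5L));
        g.put(5L, Arrays.asList(4L));
        // SCCs are {1, 2, 3} and {4, 5}, labeled by their smallest IDs 1 and 4.
        System.out.println(scc(g));
    }
}
```

Each pass of the outer loop removes at least the SCC of the smallest remaining vertex, because that vertex always keeps its own ID as its color, so the loop terminates.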
Code example
The code for strongly connected components is as follows:
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.BooleanWritable;
import com.aliyun.odps.io.IntWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;
/**
* Definition from Wikipedia:
* In the mathematical theory of directed graphs, a graph is said
* to be strongly connected if every vertex is reachable from every
* other vertex. The strongly connected components of an arbitrary
* directed graph form a partition into subgraphs that are themselves
* strongly connected.
*
* The algorithm consists of the following four phases.
* 1. Transpose Graph Formation: Requires two supersteps. In the first
* superstep, each vertex sends a message with its ID to all its outgoing
* neighbors, which in the second superstep are stored in transposeNeighbors.
*
* 2. Trimming: Takes one superstep. Every vertex with only in-coming or
* only outgoing edges (or neither) sets its colorID to its own ID and
* becomes inactive. Messages subsequently sent to the vertex are ignored.
*
* 3. Forward-Traversal: There are two sub phases: Start and Rest. In the
* Start phase, each vertex sets its colorID to its own ID and propagates
* its ID to its outgoing neighbors. In the Rest phase, vertices update
* their own colorIDs with the minimum colorID they have seen, and propagate
* their colorIDs, if updated, until the colorIDs converge.
* Set the phase to Backward-Traversal when the colorIDs converge.
*
* 4. Backward-Traversal: We again break the phase into Start and Rest.
* In Start, every vertex whose ID equals its colorID propagates its ID to
* the vertices in transposeNeighbors and sets itself inactive. Messages
* subsequently sent to the vertex are ignored. In each of the Rest phase supersteps,
* each vertex receiving a message that matches its colorID: (1) propagates
* its colorID in the transpose graph; (2) sets itself inactive. Messages
* subsequently sent to the vertex are ignored. Set the phase back to Trimming
* if not all vertices are inactive.
*
* http://ilpubs.stanford.edu:8090/1077/3/p535-salihoglu.pdf
*/
public class StronglyConnectedComponents {
public final static int STAGE_TRANSPOSE_1 = 0;
public final static int STAGE_TRANSPOSE_2 = 1;
public final static int STAGE_TRIMMING = 2;
public final static int STAGE_FW_START = 3;
public final static int STAGE_FW_REST = 4;
public final static int STAGE_BW_START = 5;
public final static int STAGE_BW_REST = 6;
/**
* The value is composed of component id, incoming neighbors,
* active status and updated status.
*/
public static class MyValue implements Writable {
LongWritable sccID;// strongly connected component id
Tuple inNeighbors; // transpose neighbors
BooleanWritable active; // vertex is active or not
BooleanWritable updated; // sccID is updated or not
public MyValue() {
this.sccID = new LongWritable(Long.MAX_VALUE);
this.inNeighbors = new Tuple();
this.active = new BooleanWritable(true);
this.updated = new BooleanWritable(false);
}
public void setSccID(LongWritable sccID) {
this.sccID = sccID;
}
public LongWritable getSccID() {
return this.sccID;
}
public void setInNeighbors(Tuple inNeighbors) {
this.inNeighbors = inNeighbors;
}
public Tuple getInNeighbors() {
return this.inNeighbors;
}
public void addInNeighbor(LongWritable neighbor) {
this.inNeighbors.append(new LongWritable(neighbor.get()));
}
public boolean isActive() {
return this.active.get();
}
public void setActive(boolean status) {
this.active.set(status);
}
public boolean isUpdated() {
return this.updated.get();
}
public void setUpdated(boolean update) {
this.updated.set(update);
}
@Override
public void write(DataOutput out) throws IOException {
this.sccID.write(out);
this.inNeighbors.write(out);
this.active.write(out);
this.updated.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
this.sccID.readFields(in);
this.inNeighbors.readFields(in);
this.active.readFields(in);
this.updated.readFields(in);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("sccID: " + sccID.get());
sb.append(" inNeighbors: " + inNeighbors.toDelimitedString(','));
sb.append(" active: " + active.get());
sb.append(" updated: " + updated.get());
return sb.toString();
}
}
public static class SCCVertex extends
Vertex<LongWritable, MyValue, NullWritable, LongWritable> {
public SCCVertex() {
this.setValue(new MyValue());
}
@Override
public void compute(
ComputeContext<LongWritable, MyValue, NullWritable, LongWritable> context,
Iterable<LongWritable> msgs) throws IOException {
// Messages sent to inactive vertex are ignored.
if (!this.getValue().isActive()) {
this.voteToHalt();
return;
}
int stage = ((SCCAggrValue)context.getLastAggregatedValue(0)).getStage();
switch (stage) {
case STAGE_TRANSPOSE_1:
context.sendMessageToNeighbors(this, this.getId());
break;
case STAGE_TRANSPOSE_2:
for (LongWritable msg: msgs) {
this.getValue().addInNeighbor(msg);
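}
// No break: execution falls through, so the first trimming pass
// runs in the same superstep that builds the transpose graph.
case STAGE_TRIMMING: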
this.getValue().setSccID(getId());
if (this.getValue().getInNeighbors().size() == 0 ||
this.getNumEdges() == 0) {
this.getValue().setActive(false);
}
break;
case STAGE_FW_START:
this.getValue().setSccID(getId());
context.sendMessageToNeighbors(this, this.getValue().getSccID());
break;
case STAGE_FW_REST:
long minSccID = Long.MAX_VALUE;
for (LongWritable msg : msgs) {
if (msg.get() < minSccID) {
minSccID = msg.get();
}
}
if (minSccID < this.getValue().getSccID().get()) {
this.getValue().setSccID(new LongWritable(minSccID));
context.sendMessageToNeighbors(this, this.getValue().getSccID());
this.getValue().setUpdated(true);
} else {
this.getValue().setUpdated(false);
}
break;
case STAGE_BW_START:
if (this.getId().equals(this.getValue().getSccID())) {
for (Writable neighbor : this.getValue().getInNeighbors().getAll()) {
context.sendMessage((LongWritable)neighbor, this.getValue().getSccID());
}
this.getValue().setActive(false);
}
break;
case STAGE_BW_REST:
this.getValue().setUpdated(false);
for (LongWritable msg : msgs) {
if (msg.equals(this.getValue().getSccID())) {
for (Writable neighbor : this.getValue().getInNeighbors().getAll()) {
context.sendMessage((LongWritable)neighbor, this.getValue().getSccID());
}
this.getValue().setActive(false);
this.getValue().setUpdated(true);
break;
}
}
break;
}
context.aggregate(0, getValue());
}
@Override
public void cleanup(
WorkerContext<LongWritable, MyValue, NullWritable, LongWritable> context)
throws IOException {
context.write(getId(), getValue().getSccID());
}
}
/**
* SCCAggrValue holds the global stage and whether any vertex was updated or is active.
* updated is true if at least one vertex was updated.
* active is true if at least one vertex is active.
*/
public static class SCCAggrValue implements Writable {
IntWritable stage = new IntWritable(STAGE_TRANSPOSE_1);
BooleanWritable updated = new BooleanWritable(false);
BooleanWritable active = new BooleanWritable(false);
public void setStage(int stage) {
this.stage.set(stage);
}
public int getStage() {
return this.stage.get();
}
public void setUpdated(boolean updated) {
this.updated.set(updated);
}
public boolean getUpdated() {
return this.updated.get();
}
public void setActive(boolean active) {
this.active.set(active);
}
public boolean getActive() {
return this.active.get();
}
@Override
public void write(DataOutput out) throws IOException {
this.stage.write(out);
this.updated.write(out);
this.active.write(out);
}
@Override
public void readFields(DataInput in) throws IOException {
this.stage.readFields(in);
this.updated.readFields(in);
this.active.readFields(in);
}
}
/**
* The job of SCCAggregator is to schedule global stage in every superstep.
*/
public static class SCCAggregator extends Aggregator<SCCAggrValue> {
@SuppressWarnings("rawtypes")
@Override
public SCCAggrValue createStartupValue(WorkerContext context) throws IOException {
return new SCCAggrValue();
}
@SuppressWarnings("rawtypes")
@Override
public SCCAggrValue createInitialValue(WorkerContext context)
throws IOException {
return (SCCAggrValue) context.getLastAggregatedValue(0);
}
@Override
public void aggregate(SCCAggrValue value, Object item) throws IOException {
MyValue v = (MyValue)item;
if ((value.getStage() == STAGE_FW_REST || value.getStage() == STAGE_BW_REST)
&& v.isUpdated()) {
value.setUpdated(true);
}
// Only active vertices call aggregate(), so any call means the graph is still active.
value.setActive(true);
}
@Override
public void merge(SCCAggrValue value, SCCAggrValue partial)
throws IOException {
boolean updated = value.getUpdated() || partial.getUpdated();
value.setUpdated(updated);
boolean active = value.getActive() || partial.getActive();
value.setActive(active);
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, SCCAggrValue value)
throws IOException {
// If all vertices are inactive, the job is over.
if (!value.getActive()) {
return true;
}
// state machine
switch (value.getStage()) {
case STAGE_TRANSPOSE_1:
value.setStage(STAGE_TRANSPOSE_2);
break;
case STAGE_TRANSPOSE_2:
value.setStage(STAGE_TRIMMING);
break;
case STAGE_TRIMMING:
value.setStage(STAGE_FW_START);
break;
case STAGE_FW_START:
value.setStage(STAGE_FW_REST);
break;
case STAGE_FW_REST:
if (value.getUpdated()) {
value.setStage(STAGE_FW_REST);
} else {
value.setStage(STAGE_BW_START);
}
break;
case STAGE_BW_START:
value.setStage(STAGE_BW_REST);
break;
case STAGE_BW_REST:
if (value.getUpdated()) {
value.setStage(STAGE_BW_REST);
} else {
value.setStage(STAGE_TRIMMING);
}
break;
}
value.setActive(false);
value.setUpdated(false);
return false;
}
}
public static class SCCVertexReader extends
GraphLoader<LongWritable, MyValue, NullWritable, LongWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, MyValue, NullWritable, LongWritable> context)
throws IOException {
SCCVertex vertex = new SCCVertex();
vertex.setId((LongWritable) record.get(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
try {
long destID = Long.parseLong(edges[i]);
vertex.addEdge(new LongWritable(destID), NullWritable.get());
} catch(NumberFormatException nfe) {
System.err.println("Ignore " + nfe);
}
}
context.addVertexRequest(vertex);
}
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
System.out.println("Usage: <input> <output>");
System.exit(-1);
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(SCCVertexReader.class);
job.setVertexClass(SCCVertex.class);
job.setAggregatorClass(SCCAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
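Connected components
Two vertices are connected if there is a path between them. An undirected graph G is connected if every two of its vertices are connected; otherwise it is disconnected. A maximal connected subgraph of G is called a connected component.
This algorithm computes the connected component membership of each vertex and outputs, for each vertex, the smallest vertex ID in its connected component. It works by propagating the smallest vertex ID along the edges to every vertex in the component. Because the graph is undirected, the input table lists each edge in both directions.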
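The propagation can be simulated on one machine as follows. This is a plain-Java sketch (the class MinIdPropagationSketch is hypothetical) in which each pass of the while loop plays the role of one superstep, and the send helper keeps only the smallest message per receiver, which is the job MinLongCombiner does in the ODPS code. It assumes every vertex appears as a key and every edge is listed in both directions.

```java
import java.util.*;

/**
 * Single-machine sketch of the min-ID propagation used by ConnectedComponents.
 * Class and method names are hypothetical; each pass of the while loop is one superstep.
 */
public class MinIdPropagationSketch {

    /**
     * Returns, for every vertex, the smallest vertex ID in its connected component.
     * Assumes every vertex is a key of adj and every edge is listed in both directions.
     */
    public static Map<Long, Long> components(Map<Long, List<Long>> adj) {
        Map<Long, Long> label = new HashMap<>();
        Map<Long, Long> inbox = new HashMap<>();
        // Superstep 0: each vertex takes its own ID and sends it to its neighbors.
        for (Long v : adj.keySet()) {
            label.put(v, v);
            send(adj, v, v, inbox);
        }
        // Later supersteps: adopt a smaller ID and forward it; otherwise stay halted.
        while (!inbox.isEmpty()) {
            Map<Long, Long> next = new HashMap<>();
            for (Map.Entry<Long, Long> msg : inbox.entrySet()) {
                Long v = msg.getKey();
                if (msg.getValue() < label.get(v)) {
                    label.put(v, msg.getValue());
                    send(adj, v, msg.getValue(), next);
                }
            }
            inbox = next;
        }
        return label;
    }

    // Delivers id to every neighbor of v, keeping only the minimum per receiver,
    // like a min combiner applied to the messages of one superstep.
    private static void send(Map<Long, List<Long>> adj, Long v, Long id, Map<Long, Long> box) {
        for (Long w : adj.get(v)) {
            box.merge(w, id, Long::min);
        }
    }

    public static void main(String[] args) {
        Map<Long, List<Long>> adj = new HashMap<>();
        adj.put(1L, Arrays.asList(2L, 3L));
        adj.put(2L, Arrays.asList(1L, 4L));
        adj.put(3L, Arrays.asList(1L, 4L));
        adj.put(4L, Arrays.asList(2L, 3L));
        adj.put(5L, Arrays.asList(6L));
        adj.put(6L, Arrays.asList(5L));
        System.out.println(components(adj));
    }
}
```

The first four vertices are the square from the example input table below and end up labeled 1; the separate edge 5-6 forms a second component labeled 5.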
Code example
The code for connected components is as follows:
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
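import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.WritableRecord;
/**
* Compute the connected component membership of each vertex and output,
* for each vertex, the smallest vertex ID in the connected component
* that contains it.
*
* Algorithm: propagate the smallest vertex ID along the edges to all
* vertices of a connected component.
*
*/
public class ConnectedComponents {
/**
* Keeps only the smallest message sent to a vertex. Defined here so that
* this example does not depend on the SSSP example class.
*/
public static class MinLongCombiner extends
Combiner<LongWritable, LongWritable> {
@Override
public void combine(LongWritable vertexId, LongWritable combinedMessage,
LongWritable messageToCombine) throws IOException {
if (combinedMessage.get() > messageToCombine.get()) {
combinedMessage.set(messageToCombine.get());
}
}
}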
public static class CCVertex extends
Vertex<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, NullWritable, LongWritable> context,
Iterable<LongWritable> msgs) throws IOException {
if (context.getSuperstep() == 0L) {
this.setValue(getId());
context.sendMessageToNeighbors(this, getValue());
return;
}
long minID = Long.MAX_VALUE;
for (LongWritable id : msgs) {
if (id.get() < minID) {
minID = id.get();
}
}
if (minID < this.getValue().get()) {
this.setValue(new LongWritable(minID));
context.sendMessageToNeighbors(this, getValue());
} else {
this.voteToHalt();
}
}
/**
* Output Table Description:
* +-------+--------+----------------------------------------+
* | Field | Type   | Comment                                |
* +-------+--------+----------------------------------------+
* | v     | bigint | vertex id                              |
* | minID | bigint | smallest id in the connected component |
* +-------+--------+----------------------------------------+
*/
@Override
public void cleanup(
WorkerContext<LongWritable, LongWritable, NullWritable, LongWritable> context)
throws IOException {
context.write(getId(), getValue());
}
}
/**
* Input Table Description:
* +-------+--------+----------------------------------------------------+
* | Field | Type   | Comment                                            |
* +-------+--------+----------------------------------------------------+
* | v     | bigint | vertex id                                          |
* | es    | string | comma separated target vertex id of outgoing edges |
* +-------+--------+----------------------------------------------------+
*
* Example:
* For graph:
* 1 ----- 2
* |       |
* 3 ----- 4
* Input table:
* +---+-----+
* | v | es  |
* +---+-----+
* | 1 | 2,3 |
* | 2 | 1,4 |
* | 3 | 1,4 |
* | 4 | 2,3 |
* +---+-----+
*/
public static class CCVertexReader extends
GraphLoader<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, NullWritable, LongWritable> context)
throws IOException {
CCVertex vertex = new CCVertex();
vertex.setId((LongWritable) record.get(0));
String[] edges = record.get(1).toString().split(",");
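for (int i = 0; i < edges.length; i++) {
try {
long destID = Long.parseLong(edges[i]);
vertex.addEdge(new LongWritable(destID), NullWritable.get());
} catch (NumberFormatException nfe) {
// Skip malformed or empty entries, for example a vertex without edges.
System.err.println("Ignore " + nfe);
}
}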
context.addVertexRequest(vertex);
}
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
System.out.println("Usage: <input> <output>");
System.exit(-1);
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(CCVertexReader.class);
job.setVertexClass(CCVertex.class);
job.setCombinerClass(MinLongCombiner.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
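Topological sort
For a directed graph, a topological order is a sequence of all its vertices in which u comes before v for every directed edge (u,v). Such an order exists if and only if the graph contains no cycle. The algorithm steps are as follows:
1. Find every vertex whose in-degree is 0 and output it.
2. Delete those vertices and all of their outgoing edges, which lowers the in-degree of their successors.
3. Repeat steps 1 and 2 until every vertex has been output, or until every remaining vertex has a nonzero in-degree, which means the graph contains a cycle.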
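The same level-by-level procedure (Kahn's algorithm) can be sketched on one machine. The class TopoLevelsSketch is hypothetical. Each level should correspond to the batch of vertices that the ODPS job outputs in one superstep, and the method returns null when some vertices are never output, which is how a cycle shows up.

```java
import java.util.*;

/** Level-by-level topological sort (Kahn's algorithm); hypothetical single-machine sketch. */
public class TopoLevelsSketch {

    /** Returns the vertices grouped by level, or null if the graph has a cycle. */
    public static List<List<Long>> levels(Map<Long, List<Long>> out) {
        // Count in-degrees (the ODPS example does this by summing messages of value 1).
        Map<Long, Integer> indegree = new TreeMap<>();
        for (Long v : out.keySet()) indegree.putIfAbsent(v, 0);
        for (List<Long> targets : out.values()) {
            for (Long w : targets) indegree.merge(w, 1, Integer::sum);
        }

        List<Long> frontier = new ArrayList<>();
        for (Map.Entry<Long, Integer> e : indegree.entrySet()) {
            if (e.getValue() == 0) frontier.add(e.getKey());
        }

        List<List<Long>> result = new ArrayList<>();
        int emitted = 0;
        while (!frontier.isEmpty()) {
            result.add(frontier);
            emitted += frontier.size();
            // Removing a level lowers the in-degree of its successors (the "-1" messages).
            List<Long> next = new ArrayList<>();
            for (Long v : frontier) {
                for (Long w : out.getOrDefault(v, Collections.emptyList())) {
                    if (indegree.merge(w, -1, Integer::sum) == 0) next.add(w);
                }
            }
            Collections.sort(next);
            frontier = next;
        }
        // Vertices left with a nonzero in-degree lie on, or behind, a cycle.
        return emitted == indegree.size() ? result : null;
    }

    public static void main(String[] args) {
        Map<Long, List<Long>> g = new HashMap<>();
        g.put(0L, Arrays.asList(1L, 2L));
        g.put(1L, Arrays.asList(3L));
        g.put(2L, Arrays.asList(3L));
        g.put(3L, new ArrayList<>());
        System.out.println(levels(g)); // [[0], [1, 2], [3]]
    }
}
```

The sample graph in main() is the same one used in the comments of the ODPS main() method below, where -1 stands for "no outgoing edges".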
Code example
The code for the topological sort algorithm is as follows:
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.BooleanWritable;
import com.aliyun.odps.io.WritableRecord;
public class TopologySort {
private final static Log LOG = LogFactory.getLog(TopologySort.class);
public static class TopologySortVertex extends
Vertex<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, NullWritable, LongWritable> context,
Iterable<LongWritable> messages) throws IOException {
// in superstep 0, each vertex sends message whose value is 1 to its
// neighbors
if (context.getSuperstep() == 0) {
if (hasEdges()) {
context.sendMessageToNeighbors(this, new LongWritable(1L));
}
} else if (context.getSuperstep() >= 1) {
// compute each vertex's indegree
long indegree = getValue().get();
for (LongWritable msg : messages) {
indegree += msg.get();
}
setValue(new LongWritable(indegree));
if (indegree == 0) {
voteToHalt();
if (hasEdges()) {
context.sendMessageToNeighbors(this, new LongWritable(-1L));
}
context.write(new LongWritable(context.getSuperstep()), getId());
LOG.info("vertex: " + getId());
}
context.aggregate(new LongWritable(indegree));
}
}
}
public static class TopologySortVertexReader extends
GraphLoader<LongWritable, LongWritable, NullWritable, LongWritable> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, NullWritable, LongWritable> context)
throws IOException {
TopologySortVertex vertex = new TopologySortVertex();
vertex.setId((LongWritable) record.get(0));
vertex.setValue(new LongWritable(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
long edge = Long.parseLong(edges[i]);
// -1 marks a vertex without outgoing edges
if (edge >= 0) {
vertex.addEdge(new LongWritable(edge), NullWritable.get());
}
}
LOG.info(record.toString());
context.addVertexRequest(vertex);
}
}
public static class LongSumCombiner extends
Combiner<LongWritable, LongWritable> {
@Override
public void combine(LongWritable vertexId, LongWritable combinedMessage,
LongWritable messageToCombine) throws IOException {
combinedMessage.set(combinedMessage.get() + messageToCombine.get());
}
}
public static class TopologySortAggregator extends
Aggregator<BooleanWritable> {
@SuppressWarnings("rawtypes")
@Override
public BooleanWritable createInitialValue(WorkerContext context)
throws IOException {
return new BooleanWritable(true);
}
@Override
public void aggregate(BooleanWritable value, Object item)
throws IOException {
boolean hasCycle = value.get();
boolean inDegreeNotZero = ((LongWritable) item).get() != 0;
value.set(hasCycle && inDegreeNotZero);
}
@Override
public void merge(BooleanWritable value, BooleanWritable partial)
throws IOException {
value.set(value.get() && partial.get());
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, BooleanWritable value)
throws IOException {
if (context.getSuperstep() == 0) {
// The initial aggregator value is true and no vertex calls aggregate()
// in superstep 0, so do not terminate yet.
return false;
}
return value.get();
}
}
public static void main(String[] args) throws IOException {
if (args.length != 2) {
System.out.println("Usage : <inputTable> <outputTable>");
System.exit(-1);
}
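// The input table has the following form:
// 0 1,2
// 1 3
// 2 3
// 3 -1
// The first column is the vertex ID; the second column lists the destination
// vertex IDs of the vertex's outgoing edges, and -1 means it has no outgoing edges.
// The output table has the following form:
// 1 0
// 2 1
// 2 2
// 3 3
// The first column is the superstep in which the vertex was output, which
// encodes the topological order; the second column is the vertex ID.
// TopologySortAggregator is used to detect whether the graph has a cycle.
// If the input graph has a cycle, the iteration ends once every active vertex
// has a nonzero in-degree.
// You can compare the record counts of the input and output tables to tell
// whether a directed graph has a cycle.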
GraphJob job = new GraphJob();
job.setGraphLoaderClass(TopologySortVertexReader.class);
job.setVertexClass(TopologySortVertex.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
job.setCombinerClass(LongSumCombiner.class);
job.setAggregatorClass(TopologySortAggregator.class);
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
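Linear regression
In statistics, linear regression is a statistical analysis method for determining the dependency between two or more variables. Unlike classification algorithms, which make discrete predictions,
regression algorithms predict continuous values. Linear regression defines the loss function as the sum of squared errors over the sample set and finds the weight vector by minimizing this loss.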
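Written out, the least-squares loss and the gradient-descent update used by the code example are shown below, with M samples, θ0 as the intercept, and n weights in total (the Dimension job parameter, which in the code equals the number of input columns including y). The 1/(2M) scaling does not change the minimizer; it is written this way to match the factors applied in the aggregator's terminate() method.

```latex
\begin{aligned}
h_\theta(x) &= \theta_0 + \sum_{j=1}^{n-1} \theta_j x_j \\
J(\theta) &= \frac{1}{2M} \sum_{i=1}^{M} \bigl( h_\theta(x^{(i)}) - y^{(i)} \bigr)^2 \\
\frac{\partial J}{\partial \theta_j} &= \frac{1}{M} \sum_{i=1}^{M} \bigl( h_\theta(x^{(i)}) - y^{(i)} \bigr)\, x_j^{(i)}, \qquad x_0^{(i)} = 1 \\
\theta_j &\leftarrow \theta_j - \alpha \, \frac{\partial J}{\partial \theta_j}, \qquad j = 0, \dots, n-1
\end{aligned}
```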
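A common way to solve it is gradient descent, which proceeds as follows:
1. Initialize the weight vector θ (all zeros in the code example) and choose a learning rate α and a number of iterations.
2. For each sample i, compute the error hθ(x(i)) − y(i) and accumulate the gradient component (hθ(x(i)) − y(i))·x(i)(j) for every dimension j, with x(i)(0) = 1 for the intercept.
3. Update every weight: θj := θj − (α/M)·(accumulated gradient)j, where M is the number of samples.
4. Repeat steps 2 and 3 until the maximum number of iterations is reached.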
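A compact in-memory version of these steps follows. GradientDescentSketch and its toy data are hypothetical; rows use the same {y, x1, x2, ...} layout as the input table, the error is computed as prediction minus label, and the gradient accumulator starts from zero in every iteration, which the ODPS aggregator has to do by hand in terminate().

```java
/**
 * In-memory sketch of batch gradient descent for linear regression. The class name and
 * the toy data in main() are hypothetical; the update mirrors the aggregator's logic.
 */
public class GradientDescentSketch {

    /** Each row is {y, x1, x2, ...}; theta[0] is the intercept, theta[j] multiplies x_j. */
    public static double[] fit(double[][] rows, double alpha, int iterations) {
        int n = rows[0].length;            // same role as the "Dimension" setting
        int m = rows.length;               // M, the number of samples
        double[] theta = new double[n];    // start from all zeros
        for (int it = 0; it < iterations; it++) {
            double[] grad = new double[n]; // fresh accumulator every iteration
            for (double[] row : rows) {
                double error = error(row, theta); // h_theta(x) - y
                grad[0] += error;                 // x_0 = 1 for the intercept
                for (int j = 1; j < n; j++) {
                    grad[j] += error * row[j];
                }
            }
            for (int j = 0; j < n; j++) {
                theta[j] -= alpha / m * grad[j];
            }
        }
        return theta;
    }

    /** J(theta) = 1/(2M) * sum of squared errors, the "lost" value the job prints. */
    public static double loss(double[][] rows, double[] theta) {
        double sum = 0.0;
        for (double[] row : rows) {
            double e = error(row, theta);
            sum += e * e;
        }
        return sum / (2.0 * rows.length);
    }

    // h_theta(x) - y for one sample, like vecMul() in the aggregator.
    private static double error(double[] row, double[] theta) {
        double h = theta[0];
        for (int j = 1; j < row.length; j++) {
            h += theta[j] * row[j];
        }
        return h - row[0];
    }

    public static void main(String[] args) {
        // Toy data generated from y = 1 + 2 * x.
        double[][] rows = {{1, 0}, {3, 1}, {5, 2}, {7, 3}};
        double[] theta = fit(rows, 0.1, 2000);
        System.out.println(theta[0] + " " + theta[1] + " loss=" + loss(rows, theta));
    }
}
```

With this toy data, fit(rows, 0.1, 2000) should land very close to θ = {1, 2}. A learning rate that is too large for the data makes the same loop diverge, which is one reason to expose α as a job parameter.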
Code example
The code for linear regression is as follows:
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.Aggregator;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;
/**
* LinearRegression input: y,x1,x2,x3,...
**/
public class LinearRegression {
public static class GradientWritable implements Writable {
Tuple lastTheta;
Tuple currentTheta;
Tuple tmpGradient;
LongWritable count;
DoubleWritable lost;
@Override
public void readFields(DataInput in) throws IOException {
lastTheta = new Tuple();
lastTheta.readFields(in);
currentTheta = new Tuple();
currentTheta.readFields(in);
tmpGradient = new Tuple();
tmpGradient.readFields(in);
count = new LongWritable();
count.readFields(in);
/* update 1: add a variable to store lost at every iteration */
lost = new DoubleWritable();
lost.readFields(in);
}
@Override
public void write(DataOutput out) throws IOException {
lastTheta.write(out);
currentTheta.write(out);
tmpGradient.write(out);
count.write(out);
lost.write(out);
}
}
public static class LinearRegressionVertex extends
Vertex<LongWritable, Tuple, NullWritable, NullWritable> {
@Override
public void compute(
ComputeContext<LongWritable, Tuple, NullWritable, NullWritable> context,
Iterable<NullWritable> messages) throws IOException {
context.aggregate(getValue());
}
}
public static class LinearRegressionVertexReader extends
GraphLoader<LongWritable, Tuple, NullWritable, NullWritable> {
@Override
public void load(LongWritable recordNum, WritableRecord record,
MutationContext<LongWritable, Tuple, NullWritable, NullWritable> context)
throws IOException {
LinearRegressionVertex vertex = new LinearRegressionVertex();
vertex.setId(recordNum);
vertex.setValue(new Tuple(record.getAll()));
context.addVertexRequest(vertex);
}
}
public static class LinearRegressionAggregator extends
Aggregator<GradientWritable> {
@SuppressWarnings("rawtypes")
@Override
public GradientWritable createInitialValue(WorkerContext context)
throws IOException {
if (context.getSuperstep() == 0) {
/* set initial value, all 0 */
GradientWritable grad = new GradientWritable();
grad.lastTheta = new Tuple();
grad.currentTheta = new Tuple();
grad.tmpGradient = new Tuple();
grad.count = new LongWritable(1);
grad.lost = new DoubleWritable(0.0);
int n = (int) Long.parseLong(context.getConfiguration()
.get("Dimension"));
for (int i = 0; i < n; i++) {
grad.lastTheta.append(new DoubleWritable(0));
grad.currentTheta.append(new DoubleWritable(0));
grad.tmpGradient.append(new DoubleWritable(0));
}
return grad;
} else
return (GradientWritable) context.getLastAggregatedValue(0);
}
public static double vecMul(Tuple value, Tuple theta) {
/* compute hθ(x(i)) − y(i) for one sample (prediction minus label) */
/* value is one sample, and value(0) is y */
double sum = 0.0;
for (int j = 1; j < value.size(); j++)
sum += Double.parseDouble(value.get(j).toString())
* Double.parseDouble(theta.get(j).toString());
Double tmp = Double.parseDouble(theta.get(0).toString()) + sum
- Double.parseDouble(value.get(0).toString());
return tmp;
}
@Override
public void aggregate(GradientWritable gradient, Object value)
throws IOException {
/*
* perform on each vertex--each sample i:set theta(j) for each sample i
* for each dimension
*/
double tmpVar = vecMul((Tuple) value, gradient.currentTheta);
/*
* update 2:local worker aggregate(), perform like merge() below. This
* means the variable gradient denotes the previous aggregated value
*/
gradient.tmpGradient.set(0, new DoubleWritable(
((DoubleWritable) gradient.tmpGradient.get(0)).get() + tmpVar));
// accumulate the squared error of every sample handled by this worker
gradient.lost.set(gradient.lost.get() + Math.pow(tmpVar, 2));
/*
* calculate (hθ(x(i))−y(i)) x(i)(j) for each sample i for each
* dimension j
*/
for (int j = 1; j < gradient.tmpGradient.size(); j++)
gradient.tmpGradient.set(j, new DoubleWritable(
((DoubleWritable) gradient.tmpGradient.get(j)).get() + tmpVar
* Double.parseDouble(((Tuple) value).get(j).toString())));
}
@Override
public void merge(GradientWritable gradient, GradientWritable partial)
throws IOException {
/* perform SumAll on each dimension for all samples. */
Tuple master = (Tuple) gradient.tmpGradient;
Tuple part = (Tuple) partial.tmpGradient;
for (int j = 0; j < gradient.tmpGradient.size(); j++) {
DoubleWritable s = (DoubleWritable) master.get(j);
s.set(s.get() + ((DoubleWritable) part.get(j)).get());
}
gradient.lost.set(gradient.lost.get() + partial.lost.get());
}
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, GradientWritable gradient)
throws IOException {
/*
* 1. calculate new theta 2. judge the diff between last step and this
* step, if smaller than the threshold, stop iteration
*/
gradient.lost = new DoubleWritable(gradient.lost.get()
/ (2 * context.getTotalNumVertices()));
/*
* we can calculate lost in order to make sure the algorithm is running on
* the right direction (for debug)
*/
System.out.println(gradient.count + " lost:" + gradient.lost);
Tuple tmpGradient = gradient.tmpGradient;
System.out.println("tmpGra" + tmpGradient);
Tuple lastTheta = gradient.lastTheta;
Tuple tmpCurrentTheta = new Tuple(gradient.currentTheta.size());
System.out.println(gradient.count + " terminate_start_last:" + lastTheta);
// learning rate: use the "Alpha" value set in main(), falling back to 0.07
String alphaConf = context.getConfiguration().get("Alpha");
double alpha = alphaConf == null ? 0.07 : Double.parseDouble(alphaConf);
/* perform theta(j) = theta(j)-alpha*tmpGradient */
long M = context.getTotalNumVertices();
/*
* update 3: divide by M. The original code forgot this step.
*/
for (int j = 0; j < lastTheta.size(); j++) {
tmpCurrentTheta
.set(
j,
new DoubleWritable(Double.parseDouble(lastTheta.get(j)
.toString())
- alpha
/ M
* Double.parseDouble(tmpGradient.get(j).toString())));
}
System.out.println(gradient.count + " terminate_start_current:"
+ tmpCurrentTheta);
// judge if convergence is happening.
double diff = 0.00d;
for (int j = 0; j < gradient.currentTheta.size(); j++)
diff += Math.pow(((DoubleWritable) tmpCurrentTheta.get(j)).get()
- ((DoubleWritable) lastTheta.get(j)).get(), 2);
// Stop after Max_Iter_Num iterations. To stop on convergence instead, the
// condition could also test Math.sqrt(diff) < 0.00000000005d.
if (Long.parseLong(context.getConfiguration().get("Max_Iter_Num")) == gradient.count
.get()) {
context.write(gradient.currentTheta.toArray());
return true;
}
gradient.lastTheta = tmpCurrentTheta;
gradient.currentTheta = tmpCurrentTheta;
gradient.count.set(gradient.count.get() + 1);
int n = (int) Long.parseLong(context.getConfiguration().get("Dimension"));
/*
* update 4: Important!!! Remember this step. Graph won't reset the
* initial value for global variables at the beginning of each iteration
*/
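for (int i = 0; i < n; i++) {
gradient.tmpGradient.set(i, new DoubleWritable(0));
}
// Reset the loss too, so the next iteration sums only its own samples.
gradient.lost = new DoubleWritable(0.0);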
return false;
}
}
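public static void main(String[] args) throws IOException {
if (args.length < 5) {
System.out.println("Usage: <input> <output> <maxIterations> <dimension> <alpha>");
System.exit(-1);
}
GraphJob job = new GraphJob();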
job.setGraphLoaderClass(LinearRegressionVertexReader.class);
job.setRuntimePartitioning(false);
job.setNumWorkers(3);
job.setVertexClass(LinearRegressionVertex.class);
job.setAggregatorClass(LinearRegressionAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
job.setMaxIteration(Integer.parseInt(args[2])); // Numbers of Iteration
job.setInt("Max_Iter_Num", Integer.parseInt(args[2]));
job.setInt("Dimension", Integer.parseInt(args[3])); // Dimension
job.setFloat("Alpha", Float.parseFloat(args[4])); // Learning rate
long start = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
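Triangle count
The triangle count algorithm computes the number of triangles that pass through each vertex.
The algorithm works as follows:
1. Each vertex sends its ID to all of its outgoing neighbors.
2. Each vertex stores its incoming and outgoing neighbors and sends this set to all of its outgoing neighbors.
3. For each edge, the destination vertex intersects the received set with its own neighbor set, sums the sizes of these intersections, and writes the result to the output table.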
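The three supersteps can be simulated on one machine as below. TriangleCountSketch is hypothetical. The value it computes for a vertex v is the sum, over v's incoming edges (u, v), of the number of common neighbors of u and v, which is what superstep 2 writes. When each undirected edge appears in only one direction, every triangle is counted once per edge, so the grand total divided by 3 is the triangle count.

```java
import java.util.*;

/**
 * Single-machine sketch of the three-superstep triangle count.
 * Class and method names are hypothetical.
 */
public class TriangleCountSketch {

    /** Value each vertex would write in superstep 2, keyed by vertex ID. */
    public static Map<Long, Long> perVertexCounts(Map<Long, List<Long>> out) {
        // Supersteps 0-1: every vertex ends up holding its out-neighbors plus its in-neighbors.
        Map<Long, Set<Long>> nbrs = new HashMap<>();
        for (Map.Entry<Long, List<Long>> e : out.entrySet()) {
            nbrs.computeIfAbsent(e.getKey(), k -> new HashSet<>()).addAll(e.getValue());
            for (Long w : e.getValue()) {
                nbrs.computeIfAbsent(w, k -> new HashSet<>()).add(e.getKey());
            }
        }
        // Superstep 2: along each edge u -> v, v counts the IDs in u's set that it also holds.
        Map<Long, Long> count = new HashMap<>();
        for (Long v : nbrs.keySet()) {
            count.put(v, 0L);
        }
        for (Map.Entry<Long, List<Long>> e : out.entrySet()) {
            Set<Long> sent = nbrs.get(e.getKey());
            for (Long v : e.getValue()) {
                long common = 0;
                for (Long id : sent) {
                    if (nbrs.get(v).contains(id)) {
                        common++;
                    }
                }
                count.merge(v, common, Long::sum);
            }
        }
        return count;
    }

    /** Total triangles, assuming each undirected edge appears in exactly one direction. */
    public static long totalTriangles(Map<Long, List<Long>> out) {
        long sum = 0;
        for (long c : perVertexCounts(out).values()) {
            sum += c;
        }
        return sum / 3;
    }

    public static void main(String[] args) {
        // Square 1-2-4-3 plus the diagonal 2-3, each edge stored once.
        Map<Long, List<Long>> g = new HashMap<>();
        g.put(1L, Arrays.asList(2L, 3L));
        g.put(2L, Arrays.asList(3L, 4L));
        g.put(3L, Arrays.asList(4L));
        System.out.println(totalTriangles(g)); // 2: {1,2,3} and {2,3,4}
    }
}
```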
Code example
The code for the triangle count algorithm is as follows:
import java.io.IOException;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableRecord;
/**
* Compute the number of triangles passing through each vertex.
*
* The algorithm can be computed in three supersteps:
* I. Each vertex sends a message with its ID to all its outgoing
* neighbors.
* II. The incoming neighbors and outgoing neighbors are stored and
* sent to outgoing neighbors.
* III. For each edge compute the intersection of the sets at destination
* vertex and sum them, then output to table.
*
* If each undirected edge is stored in only one direction, the total number
* of triangles is the sum of the output table divided by three, since each
* triangle is counted once for each of its three edges.
*
**/
public class TriangleCount {
public static class TCVertex extends
Vertex<LongWritable, Tuple, NullWritable, Tuple> {
@Override
public void setup(
WorkerContext<LongWritable, Tuple, NullWritable, Tuple> context)
throws IOException {
// collect the outgoing neighbors
Tuple t = new Tuple();
if (this.hasEdges()) {
for (Edge<LongWritable, NullWritable> edge : this.getEdges()) {
t.append(edge.getDestVertexId());
}
}
this.setValue(t);
}
@Override
public void compute(
ComputeContext<LongWritable, Tuple, NullWritable, Tuple> context,
Iterable<Tuple> msgs) throws IOException {
if (context.getSuperstep() == 0L) {
// sends a message with its ID to all its outgoing neighbors
Tuple t = new Tuple();
t.append(getId());
context.sendMessageToNeighbors(this, t);
} else if (context.getSuperstep() == 1L) {
// store the incoming neighbors
for (Tuple msg : msgs) {
for (Writable item : msg.getAll()) {
if (!this.getValue().getAll().contains((LongWritable)item)) {
this.getValue().append((LongWritable)item);
}
}
}
// send both incoming and outgoing neighbors to all outgoing neighbors
context.sendMessageToNeighbors(this, getValue());
} else if (context.getSuperstep() == 2L) {
// count the sum of intersection at each edge
long count = 0;
for (Tuple msg : msgs) {
for (Writable id : msg.getAll()) {
if (getValue().getAll().contains(id)) {
count ++;
}
}
}
// output to table
context.write(getId(), new LongWritable(count));
this.voteToHalt();
}
}
}
public static class TCVertexReader extends
GraphLoader<LongWritable, Tuple, NullWritable, Tuple> {
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, Tuple, NullWritable, Tuple> context)
throws IOException {
TCVertex vertex = new TCVertex();
vertex.setId((LongWritable) record.get(0));
String[] edges = record.get(1).toString().split(",");
for (int i = 0; i < edges.length; i++) {
try {
long destID = Long.parseLong(edges[i]);
vertex.addEdge(new LongWritable(destID), NullWritable.get());
} catch(NumberFormatException nfe) {
System.err.println("Ignore " + nfe);
}
}
context.addVertexRequest(vertex);
}
}
public static void main(String[] args) throws IOException {
if (args.length < 2) {
System.out.println("Usage: <input> <output>");
System.exit(-1);
}
GraphJob job = new GraphJob();
job.setGraphLoaderClass(TCVertexReader.class);
job.setVertexClass(TCVertex.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
long startTime = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
}
}
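To make the three supersteps of TCVertex concrete, the following plain-Java sketch replays them in a single process on a small in-memory graph. It is not ODPS Graph code. The class `TriangleCountSketch` and its methods are hypothetical, a `Map` stands in for the distributed vertices, and the sketch assumes every vertex appears as a key of the adjacency map. If each undirected edge is stored exactly once, in one direction, every triangle is counted once along each of its three edges, so the per-vertex counts sum to three times the number of triangles.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/**
 * Single-process sketch of the TCVertex supersteps (hypothetical helper, not
 * part of the ODPS Graph SDK). Assumes every vertex is a key of outEdges.
 */
class TriangleCountSketch {

  /** Returns vertex ID -> the count TCVertex would write in superstep 2. */
  static Map<Long, Long> countPerVertex(Map<Long, List<Long>> outEdges) {
    // setup(): each vertex starts with its outgoing neighbors.
    Map<Long, Set<Long>> neighbors = new TreeMap<>();
    for (Map.Entry<Long, List<Long>> e : outEdges.entrySet()) {
      neighbors.put(e.getKey(), new LinkedHashSet<>(e.getValue()));
    }
    // Supersteps 0-1: each vertex sends its ID along its outgoing edges and
    // the receivers merge those IDs (their incoming neighbors) into the set.
    for (Map.Entry<Long, List<Long>> e : outEdges.entrySet()) {
      for (Long dst : e.getValue()) {
        neighbors.get(dst).add(e.getKey());
      }
    }
    // Supersteps 1-2: u sends its full neighbor set along each edge u -> v,
    // and v counts the received IDs that are also its own neighbors.
    Map<Long, Long> counts = new TreeMap<>();
    for (Long v : neighbors.keySet()) {
      counts.put(v, 0L);
    }
    for (Map.Entry<Long, List<Long>> e : outEdges.entrySet()) {
      Set<Long> fromU = neighbors.get(e.getKey());
      for (Long v : e.getValue()) {
        Set<Long> own = neighbors.get(v);
        long shared = 0;
        for (Long id : fromU) {
          if (own.contains(id)) {
            shared++;
          }
        }
        counts.merge(v, shared, Long::sum);
      }
    }
    return counts;
  }

  /** Triangle 1-2-3 plus a pendant edge 3->4; each undirected edge stored once. */
  static Map<Long, List<Long>> sampleGraph() {
    Map<Long, List<Long>> g = new TreeMap<>();
    g.put(1L, List.of(2L, 3L));
    g.put(2L, List.of(3L));
    g.put(3L, List.of(4L));
    g.put(4L, List.of());
    return g;
  }

  public static void main(String[] args) {
    Map<Long, Long> counts = countPerVertex(sampleGraph());
    long total = counts.values().stream().mapToLong(Long::longValue).sum();
    System.out.println(counts);    // {1=0, 2=1, 3=2, 4=0}
    System.out.println(total / 3); // 1 triangle
  }
}
```

Using a `LinkedHashSet` plays the same role as the `contains` check in superstep 1 of TCVertex: an ID that is both an incoming and an outgoing neighbor is stored only once.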
Vertex input table example
The code for loading a vertex input table is as follows:
import java.io.IOException;
import com.aliyun.odps.conf.Configuration;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.VertexResolver;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.VertexChanges;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.WritableComparable;
import com.aliyun.odps.io.WritableRecord;
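/**
* This example shows how to write a graph job that loads data of different
* types. It mainly demonstrates how GraphLoader and VertexResolver work
* together to build the graph.
*
* The input of an ODPS Graph job is always ODPS tables. Assume the job has two
* input tables: one stores vertex information and the other stores edge
* information.
* The vertex table has the following format:
* +------------------------+
* | VertexID | VertexValue |
* +------------------------+
* |      id0 |           9 |
* +------------------------+
* |      id1 |           7 |
* +------------------------+
* |      id2 |           8 |
* +------------------------+
*
* The edge table has the following format:
* +-------------------------------------+
* | VertexID | DestVertexID | EdgeValue |
* +-------------------------------------+
* |      id0 |          id1 |         1 |
* +-------------------------------------+
* |      id0 |          id2 |         2 |
* +-------------------------------------+
* |      id2 |          id1 |         3 |
* +-------------------------------------+
*
* Taken together, the two tables mean that id0 has two outgoing edges, to id1
* and id2; id2 has one outgoing edge, to id1; and id1 has no outgoing edges.
*
* For data of this kind, GraphLoader#load(LongWritable, WritableRecord, MutationContext)
* can call {@link MutationContext#addVertexRequest(Vertex)} to request that a
* vertex be added to the graph, and
* {@link MutationContext#addEdgeRequest(WritableComparable, Edge)} to request
* that an edge be added. Then
* {@link VertexResolver#resolve(WritableComparable, Vertex, VertexChanges, boolean)}
* merges the vertices and edges added in the load method into one Vertex object,
* which it returns so that it is added to the graph that is finally computed.
*
**/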
public class VertexInputFormat {
private final static String EDGE_TABLE = "edge.table";
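/**
* Interprets each Record as a Vertex or an Edge, depending on which table the
* Record comes from.
*
* This is similar to com.aliyun.odps.mapreduce.Mapper#map: it takes a Record
* and produces key-value pairs, where the key is the vertex ID and the value is
* a Vertex or an Edge. The pairs are emitted through the context and are
* grouped by vertex ID in the loading VertexResolver (LoadingResolver).
*
* Note: the vertices and edges added here are only requests derived from the
* Record contents, not the vertices and edges that take part in the
* computation. Only the vertices and edges added in the subsequent
* VertexResolver take part in the computation.
**/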
public static class VertexInputLoader extends
GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {
private boolean isEdgeData;
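/**
* Configures the VertexInputLoader.
*
* @param conf
* job configuration, set through GraphJob in main or with set in the console
* @param workerId
* index of the current worker, starting from 0; can be used to build unique vertex IDs
* @param inputTableInfo
* input table loaded by the current worker; used to determine which kind of data,
* that is, which Record format, the current input contains
**/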
@Override
public void setup(Configuration conf, int workerId, TableInfo inputTableInfo) {
isEdgeData = conf.get(EDGE_TABLE).equals(inputTableInfo.getTableName());
}
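/**
* Parses the content of a Record into the corresponding edge or vertex and
* requests that it be added to the graph.
*
* @param recordNum
* record sequence number, starting from 1 and counted separately on each worker
* @param record
* a record from an input table: for the edge table, three columns (source
* vertex, destination vertex, edge weight); for the vertex table, two columns
* (vertex ID, vertex value)
* @param context
* context used to request that the parsed edge or vertex be added to the graph
**/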
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
throws IOException {
if (isEdgeData) {
/**
* The data comes from the edge table.
*
* 1. The first column is the ID of the source vertex.
**/
LongWritable sourceVertexID = (LongWritable) record.get(0);
/**
* 2. The second column is the ID of the destination vertex.
**/
LongWritable destinationVertexID = (LongWritable) record.get(1);
/**
* 3. The third column is the edge weight.
**/
LongWritable edgeValue = (LongWritable) record.get(2);
/**
* 4. Create the edge from the destination vertex ID and the edge weight.
**/
Edge<LongWritable, LongWritable> edge = new Edge<LongWritable, LongWritable>(
destinationVertexID, edgeValue);
/**
* 5. Request that the edge be added to the source vertex.
**/
context.addEdgeRequest(sourceVertexID, edge);
/**
* 6. If each Record represents a bidirectional edge, repeat steps 4 and 5 in
* the reverse direction:
* Edge<LongWritable, LongWritable> edge2 = new Edge<LongWritable, LongWritable>(
* sourceVertexID, edgeValue);
* context.addEdgeRequest(destinationVertexID, edge2);
**/
} else {
/**
* The data comes from the vertex table.
*
* 1. The first column is the vertex ID.
**/
LongWritable vertexID = (LongWritable) record.get(0);
/**
* 2. The second column is the vertex value.
**/
LongWritable vertexValue = (LongWritable) record.get(1);
/**
* 3. Create a vertex made up of the vertex ID and the vertex value.
**/
MyVertex vertex = new MyVertex();
/**
* 4. Initialize the vertex.
**/
vertex.setId(vertexID);
vertex.setValue(vertexValue);
/**
* 5. Request that the vertex be added.
**/
context.addVertexRequest(vertex);
}
}
}
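/**
* Aggregates the key-value pairs produced by
* GraphLoader#load(LongWritable, WritableRecord, MutationContext), similar to
* reduce in com.aliyun.odps.mapreduce.Reducer. For each distinct vertex ID, all
* requests to add or remove vertices or edges for that ID are collected in
* VertexChanges.
*
* Note: this is not called only for conflicting vertices or edges added in the
* load method (a conflict means adding several Vertex objects with the same ID,
* adding duplicate edges, and so on). It is called for every ID requested in
* the load method.
**/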
public static class LoadingResolver extends
VertexResolver<LongWritable, LongWritable, LongWritable, LongWritable> {
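/**
* Handles the requests to add or remove vertices or edges for one ID.
*
* VertexChanges has four methods, each corresponding to a method of
* MutationContext:
* VertexChanges#getAddedVertexList() corresponds to
* MutationContext#addVertexRequest(Vertex): the Vertex objects with the same ID
* requested in the load method are collected into the returned List.
* VertexChanges#getAddedEdgeList() corresponds to
* MutationContext#addEdgeRequest(WritableComparable, Edge): the Edge objects
* requested for the same source vertex ID are collected into the returned List.
* VertexChanges#getRemovedVertexCount() corresponds to
* MutationContext#removeVertexRequest(WritableComparable): the return value is
* the total number of removal requests for the vertex with this ID.
* VertexChanges#getRemovedEdgeList() corresponds to
* MutationContext#removeEdgeRequest(WritableComparable, WritableComparable): the
* Edge objects requested for removal from the same source vertex ID are
* collected into the returned List.
*
* By processing the changes for this ID, the user declares through the return
* value whether the ID takes part in the computation: if the returned Vertex is
* not null, the ID takes part in the subsequent computation; if null is
* returned, it does not.
*
* @param vertexId
* ID of the vertex requested to be added, or source vertex ID of the edge
* requested to be added
* @param vertex
* the existing Vertex object; always null during data loading
* @param vertexChanges
* the set of add/remove requests for vertices and edges on this ID
* @param hasMessages
* whether this ID has incoming messages; always false during data loading
**/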
@Override
public Vertex<LongWritable, LongWritable, LongWritable, LongWritable> resolve(
LongWritable vertexId,
Vertex<LongWritable, LongWritable, LongWritable, LongWritable> vertex,
VertexChanges<LongWritable, LongWritable, LongWritable, LongWritable> vertexChanges,
boolean hasMessages) throws IOException {
/**
* 1. Get the Vertex object that takes part in the computation.
**/
MyVertex computeVertex = null;
if (vertexChanges.getAddedVertexList() == null
|| vertexChanges.getAddedVertexList().isEmpty()) {
computeVertex = new MyVertex();
computeVertex.setId(vertexId);
} else {
/**
* This assumes that each Record in the vertex table represents a unique vertex.
**/
computeVertex = (MyVertex) vertexChanges.getAddedVertexList().get(0);
}
/**
* 2. Add the edges requested for this vertex to the Vertex object. If the data
* may contain duplicates, decide whether to deduplicate according to the needs
* of the algorithm.
**/
if (vertexChanges.getAddedEdgeList() != null) {
for (Edge<LongWritable, LongWritable> edge : vertexChanges
.getAddedEdgeList()) {
computeVertex.addEdge(edge.getDestVertexId(), edge.getValue());
}
}
/**
* 3. Return the Vertex object so that it is added to the final graph and takes
* part in the computation.
**/
return computeVertex;
}
}
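/**
* Defines the behavior of the vertices that take part in the computation.
*
**/
public static class MyVertex extends
Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {
/**
* Writes the vertex and its edges back to the result tables in the format of
* the input tables. The output tables have the same format and data as the
* input tables.
*
* @param context
* runtime context
* @param messages
* incoming messages
**/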
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
Iterable<LongWritable> messages) throws IOException {
/**
* Write the vertex ID and value to the vertex result table.
**/
context.write("vertex", getId(), getValue());
/**
* Write the vertex's edges to the edge result table.
**/
if (hasEdges()) {
for (Edge<LongWritable, LongWritable> edge : getEdges()) {
context.write("edge", getId(), edge.getDestVertexId(),
edge.getValue());
}
}
/**
* Run only a single iteration.
**/
voteToHalt();
}
}
/**
* @param args
* @throws IOException
*/
public static void main(String[] args) throws IOException {
if (args.length < 4) {
throw new IOException(
"Usage: VertexInputFormat <vertex input> <edge input> <vertex output> <edge output>");
}
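/**
* GraphJob is used to configure the graph job.
*/
GraphJob job = new GraphJob();
/**
* 1. Specify the input graph data, and specify which table stores the edge data.
*/
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addInput(TableInfo.builder().tableName(args[1]).build());
job.set(EDGE_TABLE, args[1]);
/**
* 2. Specify how data is loaded. Each Record is interpreted as a vertex or an
* edge; this is similar to map, where the generated key is the vertex ID and
* the value is the vertex or edge.
*/
job.setGraphLoaderClass(VertexInputLoader.class);
/**
* 3. Specify how the vertices that take part in the computation are generated
* during loading. This is similar to reduce: the vertices and edges produced by
* the map step are merged into one vertex per ID.
*/
job.setLoadingVertexResolverClass(LoadingResolver.class);
/**
* 4. Specify the behavior of the vertices in the computation. Vertex.compute is
* executed in every iteration.
*/
job.setVertexClass(MyVertex.class);
/**
* 5. Specify the output tables of the graph job; the computed results are
* written to these tables.
*/
job.addOutput(TableInfo.builder().tableName(args[2]).label("vertex").build());
job.addOutput(TableInfo.builder().tableName(args[3]).label("edge").build());
/**
* 6. Submit the job for execution.
*/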
job.run();
}
}
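The cooperation between VertexInputLoader and LoadingResolver resembles a map step followed by a reduce step keyed on vertex ID. The plain-Java sketch below replays that flow on the sample tables from the class comment, using the numeric IDs 0, 1 and 2 in place of id0, id1 and id2, because the loader reads IDs as LongWritable. Everything in it (`LoadResolveSketch`, `ResolvedVertex`, the `long[]` rows) is hypothetical and only imitates the roles of MutationContext#addVertexRequest, MutationContext#addEdgeRequest and VertexResolver#resolve; it does not call the ODPS Graph SDK. The `bidirectional` flag mirrors the optional step 6 in the loader.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.TreeSet;

/** Hypothetical single-process imitation of the load -> resolve flow. */
class LoadResolveSketch {

  /** A vertex as resolve would return it: ID, value, and outgoing edges. */
  static final class ResolvedVertex {
    final long id;
    Long value; // stays null when only edge rows mention this ID
    final List<long[]> edges = new ArrayList<>(); // each entry: {destId, weight}

    ResolvedVertex(long id) {
      this.id = id;
    }
  }

  /** vertexRows: {id, value}; edgeRows: {src, dst, weight}. */
  static Map<Long, ResolvedVertex> loadAndResolve(
      List<long[]> vertexRows, List<long[]> edgeRows, boolean bidirectional) {
    // "load": turn each row into an add-vertex or add-edge request keyed by ID.
    Map<Long, List<Long>> addedVertices = new TreeMap<>();
    Map<Long, List<long[]>> addedEdges = new TreeMap<>();
    for (long[] row : vertexRows) {
      addedVertices.computeIfAbsent(row[0], k -> new ArrayList<>()).add(row[1]);
    }
    for (long[] row : edgeRows) {
      addedEdges.computeIfAbsent(row[0], k -> new ArrayList<>())
          .add(new long[] {row[1], row[2]});
      if (bidirectional) { // like step 6: also request the reverse edge
        addedEdges.computeIfAbsent(row[1], k -> new ArrayList<>())
            .add(new long[] {row[0], row[2]});
      }
    }
    // "resolve": runs once for every ID that received at least one request.
    TreeSet<Long> ids = new TreeSet<>(addedVertices.keySet());
    ids.addAll(addedEdges.keySet());
    Map<Long, ResolvedVertex> graph = new TreeMap<>();
    for (long id : ids) {
      ResolvedVertex v = new ResolvedVertex(id);
      List<Long> values = addedVertices.get(id);
      if (values != null && !values.isEmpty()) {
        v.value = values.get(0); // assumes one row per vertex, as LoadingResolver does
      }
      v.edges.addAll(addedEdges.getOrDefault(id, new ArrayList<>()));
      graph.put(id, v);
    }
    return graph;
  }

  public static void main(String[] args) {
    List<long[]> vertices = List.of(new long[] {0, 9}, new long[] {1, 7}, new long[] {2, 8});
    List<long[]> edges = List.of(new long[] {0, 1, 1}, new long[] {0, 2, 2}, new long[] {2, 1, 3});
    for (ResolvedVertex v : loadAndResolve(vertices, edges, false).values()) {
      StringBuilder sb = new StringBuilder();
      for (long[] e : v.edges) {
        sb.append(" ->").append(e[0]).append("(w=").append(e[1]).append(')');
      }
      System.out.println(v.id + " value=" + v.value + sb);
    }
  }
}
```

Running `main` prints `0 value=9 ->1(w=1) ->2(w=2)`, `1 value=7` and `2 value=8 ->1(w=3)`, which matches the graph described in the class comment: vertex 1 exists only because the vertex table requested it, since no edge request is keyed on it.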
Edge input table example
The code for loading an edge input table is as follows:
import java.io.IOException;
import com.aliyun.odps.conf.Configuration;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.VertexResolver;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.VertexChanges;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.WritableComparable;
import com.aliyun.odps.io.WritableRecord;
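/**
* This example shows how to write a graph job that loads data of different
* types. It mainly demonstrates how GraphLoader and VertexResolver work
* together to build the graph.
*
* The input of an ODPS Graph job is always ODPS tables. Assume the job has two
* input tables: one stores vertex information and the other stores edge
* information.
* The vertex table has the following format:
* +------------------------+
* | VertexID | VertexValue |
* +------------------------+
* |      id0 |           9 |
* +------------------------+
* |      id1 |           7 |
* +------------------------+
* |      id2 |           8 |
* +------------------------+
*
* The edge table has the following format:
* +-------------------------------------+
* | VertexID | DestVertexID | EdgeValue |
* +-------------------------------------+
* |      id0 |          id1 |         1 |
* +-------------------------------------+
* |      id0 |          id2 |         2 |
* +-------------------------------------+
* |      id2 |          id1 |         3 |
* +-------------------------------------+
*
* Taken together, the two tables mean that id0 has two outgoing edges, to id1
* and id2; id2 has one outgoing edge, to id1; and id1 has no outgoing edges.
*
* For data of this kind, GraphLoader#load(LongWritable, WritableRecord, MutationContext)
* can call {@link MutationContext#addVertexRequest(Vertex)} to request that a
* vertex be added to the graph, and
* {@link MutationContext#addEdgeRequest(WritableComparable, Edge)} to request
* that an edge be added. Then
* {@link VertexResolver#resolve(WritableComparable, Vertex, VertexChanges, boolean)}
* merges the vertices and edges added in the load method into one Vertex object,
* which it returns so that it is added to the graph that is finally computed.
*
**/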
public class VertexInputFormat {
private final static String EDGE_TABLE = "edge.table";
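/**
* Interprets each Record as a Vertex or an Edge, depending on which table the
* Record comes from.
* <p>
* This is similar to com.aliyun.odps.mapreduce.Mapper#map: it takes a Record
* and produces key-value pairs, where the key is the vertex ID and the value is
* a Vertex or an Edge. The pairs are emitted through the context and are
* grouped by vertex ID in the loading VertexResolver (LoadingResolver).
*
* Note: the vertices and edges added here are only requests derived from the
* Record contents, not the vertices and edges that take part in the
* computation. Only the vertices and edges added in the subsequent
* VertexResolver take part in the computation.
**/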
public static class VertexInputLoader extends
GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {
private boolean isEdgeData;
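/**
* Configures the VertexInputLoader.
*
* @param conf
* job configuration, set through GraphJob in main or with set in the console
* @param workerId
* index of the current worker, starting from 0; can be used to build unique vertex IDs
* @param inputTableInfo
* input table loaded by the current worker; used to determine which kind of data,
* that is, which Record format, the current input contains
**/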
@Override
public void setup(Configuration conf, int workerId, TableInfo inputTableInfo) {
isEdgeData = conf.get(EDGE_TABLE).equals(inputTableInfo.getTableName());
}
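/**
* Parses the content of a Record into the corresponding edge or vertex and
* requests that it be added to the graph.
*
* @param recordNum
* record sequence number, starting from 1 and counted separately on each worker
* @param record
* a record from an input table: for the edge table, three columns (source
* vertex, destination vertex, edge weight); for the vertex table, two columns
* (vertex ID, vertex value)
* @param context
* context used to request that the parsed edge or vertex be added to the graph
**/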
@Override
public void load(
LongWritable recordNum,
WritableRecord record,
MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
throws IOException {
if (isEdgeData) {
/**
* The data comes from the edge table.
*
* 1. The first column is the ID of the source vertex.
**/
LongWritable sourceVertexID = (LongWritable) record.get(0);
/**
* 2. The second column is the ID of the destination vertex.
**/
LongWritable destinationVertexID = (LongWritable) record.get(1);
/**
* 3. The third column is the edge weight.
**/
LongWritable edgeValue = (LongWritable) record.get(2);
/**
* 4. Create the edge from the destination vertex ID and the edge weight.
**/
Edge<LongWritable, LongWritable> edge = new Edge<LongWritable, LongWritable>(
destinationVertexID, edgeValue);
/**
* 5. Request that the edge be added to the source vertex.
**/
context.addEdgeRequest(sourceVertexID, edge);
/**
* 6. If each Record represents a bidirectional edge, repeat steps 4 and 5 in
* the reverse direction:
* Edge<LongWritable, LongWritable> edge2 = new Edge<LongWritable, LongWritable>(
* sourceVertexID, edgeValue);
* context.addEdgeRequest(destinationVertexID, edge2);
**/
} else {
/**
* The data comes from the vertex table.
*
* 1. The first column is the vertex ID.
**/
LongWritable vertexID = (LongWritable) record.get(0);
/**
* 2. The second column is the vertex value.
**/
LongWritable vertexValue = (LongWritable) record.get(1);
/**
* 3. Create a vertex made up of the vertex ID and the vertex value.
**/
MyVertex vertex = new MyVertex();
/**
* 4. Initialize the vertex.
**/
vertex.setId(vertexID);
vertex.setValue(vertexValue);
/**
* 5. Request that the vertex be added.
**/
context.addVertexRequest(vertex);
}
}
}
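/**
* Aggregates the key-value pairs produced by
* GraphLoader#load(LongWritable, WritableRecord, MutationContext), similar to
* reduce in com.aliyun.odps.mapreduce.Reducer. For each distinct vertex ID, all
* requests to add or remove vertices or edges for that ID are collected in
* VertexChanges.
*
* Note: this is not called only for conflicting vertices or edges added in the
* load method (a conflict means adding several Vertex objects with the same ID,
* adding duplicate edges, and so on). It is called for every ID requested in
* the load method.
**/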
public static class LoadingResolver extends
VertexResolver<LongWritable, LongWritable, LongWritable, LongWritable> {
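/**
* Handles the requests to add or remove vertices or edges for one ID.
*
* VertexChanges has four methods, each corresponding to a method of
* MutationContext:
* VertexChanges#getAddedVertexList() corresponds to
* MutationContext#addVertexRequest(Vertex): the Vertex objects with the same ID
* requested in the load method are collected into the returned List.
* VertexChanges#getAddedEdgeList() corresponds to
* MutationContext#addEdgeRequest(WritableComparable, Edge): the Edge objects
* requested for the same source vertex ID are collected into the returned List.
* VertexChanges#getRemovedVertexCount() corresponds to
* MutationContext#removeVertexRequest(WritableComparable): the return value is
* the total number of removal requests for the vertex with this ID.
* VertexChanges#getRemovedEdgeList() corresponds to
* MutationContext#removeEdgeRequest(WritableComparable, WritableComparable): the
* Edge objects requested for removal from the same source vertex ID are
* collected into the returned List.
*
* By processing the changes for this ID, the user declares through the return
* value whether the ID takes part in the computation: if the returned Vertex is
* not null, the ID takes part in the subsequent computation; if null is
* returned, it does not.
*
* @param vertexId
* ID of the vertex requested to be added, or source vertex ID of the edge
* requested to be added
* @param vertex
* the existing Vertex object; always null during data loading
* @param vertexChanges
* the set of add/remove requests for vertices and edges on this ID
* @param hasMessages
* whether this ID has incoming messages; always false during data loading
**/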
@Override
public Vertex<LongWritable, LongWritable, LongWritable, LongWritable> resolve(
LongWritable vertexId,
Vertex<LongWritable, LongWritable, LongWritable, LongWritable> vertex,
VertexChanges<LongWritable, LongWritable, LongWritable, LongWritable> vertexChanges,
boolean hasMessages) throws IOException {
/**
* 1. Get the Vertex object that takes part in the computation.
**/
MyVertex computeVertex = null;
if (vertexChanges.getAddedVertexList() == null
|| vertexChanges.getAddedVertexList().isEmpty()) {
computeVertex = new MyVertex();
computeVertex.setId(vertexId);
} else {
/**
* This assumes that each Record in the vertex table represents a unique vertex.
**/
computeVertex = (MyVertex) vertexChanges.getAddedVertexList().get(0);
}
/**
* 2. Add the edges requested for this vertex to the Vertex object. If the data
* may contain duplicates, decide whether to deduplicate according to the needs
* of the algorithm.
**/
if (vertexChanges.getAddedEdgeList() != null) {
for (Edge<LongWritable, LongWritable> edge : vertexChanges
.getAddedEdgeList()) {
computeVertex.addEdge(edge.getDestVertexId(), edge.getValue());
}
}
/**
* 3. Return the Vertex object so that it is added to the final graph and takes
* part in the computation.
**/
return computeVertex;
}
}
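/**
* Defines the behavior of the vertices that take part in the computation.
*
**/
public static class MyVertex extends
Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {
/**
* Writes the vertex and its edges back to the result tables in the format of
* the input tables. The output tables have the same format and data as the
* input tables.
*
* @param context
* runtime context
* @param messages
* incoming messages
**/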
@Override
public void compute(
ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
Iterable<LongWritable> messages) throws IOException {
/**
* Write the vertex ID and value to the vertex result table.
**/
context.write("vertex", getId(), getValue());
/**
* Write the vertex's edges to the edge result table.
**/
if (hasEdges()) {
for (Edge<LongWritable, LongWritable> edge : getEdges()) {
context.write("edge", getId(), edge.getDestVertexId(),
edge.getValue());
}
}
/**
* Run only a single iteration.
**/
voteToHalt();
}
}
/**
* @param args
* @throws IOException
*/
public static void main(String[] args) throws IOException {
if (args.length < 4) {
throw new IOException(
"Usage: VertexInputFormat <vertex input> <edge input> <vertex output> <edge output>");
}
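/**
* GraphJob is used to configure the graph job.
*/
GraphJob job = new GraphJob();
/**
* 1. Specify the input graph data, and specify which table stores the edge data.
*/
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addInput(TableInfo.builder().tableName(args[1]).build());
job.set(EDGE_TABLE, args[1]);
/**
* 2. Specify how data is loaded. Each Record is interpreted as a vertex or an
* edge; this is similar to map, where the generated key is the vertex ID and
* the value is the vertex or edge.
*/
job.setGraphLoaderClass(VertexInputLoader.class);
/**
* 3. Specify how the vertices that take part in the computation are generated
* during loading. This is similar to reduce: the vertices and edges produced by
* the map step are merged into one vertex per ID.
*/
job.setLoadingVertexResolverClass(LoadingResolver.class);
/**
* 4. Specify the behavior of the vertices in the computation. Vertex.compute is
* executed in every iteration.
*/
job.setVertexClass(MyVertex.class);
/**
* 5. Specify the output tables of the graph job; the computed results are
* written to these tables.
*/
job.addOutput(TableInfo.builder().tableName(args[2]).label("vertex").build());
job.addOutput(TableInfo.builder().tableName(args[3]).label("edge").build());
/**
* 6. Submit the job for execution.
*/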
job.run();
}
}
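The resolve Javadoc describes four kinds of requests collected in VertexChanges and a contract in which returning null keeps an ID out of the computation, while LoadingResolver itself only uses the two "added" lists. The plain-Java sketch below models one resolve call that also honors the removal requests and deduplicates edges by destination, as the step 2 comment suggests an algorithm may need. The `Changes` and `Resolved` classes are hypothetical stand-ins named after the VertexChanges getters, not the ODPS API, and the policy itself (any vertex-removal request drops the ID; for duplicate destinations the first weight wins; edge removals are applied after additions) is an illustrative assumption, not documented ODPS behavior.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical plain-Java model of one resolve() call; not the ODPS API. */
class ResolvePolicySketch {

  /** Stand-in for VertexChanges: the four kinds of requests for one ID. */
  static final class Changes {
    final List<Long> addedVertexValues = new ArrayList<>(); // ~ getAddedVertexList()
    final List<long[]> addedEdges = new ArrayList<>();      // ~ getAddedEdgeList(), {dst, weight}
    int removedVertexCount;                                 // ~ getRemovedVertexCount()
    final List<Long> removedEdgeDests = new ArrayList<>();  // ~ getRemovedEdgeList(), dst IDs
  }

  /** Stand-in for the returned Vertex: value plus dest -> weight edges. */
  static final class Resolved {
    Long value;
    final Map<Long, Long> edges = new LinkedHashMap<>();
  }

  /** Returns null when the ID should not take part in the computation. */
  static Resolved resolve(Changes changes) {
    // Assumed policy: any removal request for the vertex drops the ID entirely.
    if (changes.removedVertexCount > 0) {
      return null;
    }
    Resolved v = new Resolved();
    if (!changes.addedVertexValues.isEmpty()) {
      v.value = changes.addedVertexValues.get(0);
    }
    // Deduplicate by destination: keep the first weight seen for each dest.
    for (long[] e : changes.addedEdges) {
      v.edges.putIfAbsent(e[0], e[1]);
    }
    // Apply edge removal requests after the additions.
    for (Long dst : changes.removedEdgeDests) {
      v.edges.remove(dst);
    }
    return v;
  }

  public static void main(String[] args) {
    Changes c = new Changes();
    c.addedVertexValues.add(9L);
    c.addedEdges.add(new long[] {1, 1});
    c.addedEdges.add(new long[] {1, 5}); // duplicate destination, ignored
    c.addedEdges.add(new long[] {2, 2});
    c.removedEdgeDests.add(2L);
    Resolved v = resolve(c);
    System.out.println(v.value + " " + v.edges); // 9 {1=1}
    c.removedVertexCount = 1;
    System.out.println(resolve(c)); // null
  }
}
```

Whether duplicates should be merged, summed, or kept depends on the algorithm; `putIfAbsent` is just the simplest choice for the sketch.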