This topic describes the implementation mechanism and related API operations of Aggregator. It also describes how to apply Aggregator by using k-means clustering.
Aggregator is a common feature in MaxCompute Graph jobs and is especially well suited to machine learning problems. In MaxCompute Graph, Aggregator is used to aggregate and process global information.
Implementation mechanism
The logic of Aggregator is divided into two parts:
- One part runs on all workers in distributed mode.
- The other part runs only on the worker where the Aggregator owner resides (single-point mode).

The workflow is as follows:
- When each worker starts, it executes `createStartupValue` to create `AggregatorValue`.
- Before each iteration starts, each worker executes `createInitialValue` to initialize `AggregatorValue` for the iteration.
- During an iteration, each vertex calls `context.aggregate()`, which in turn calls `aggregate()` to perform partial aggregation within the worker.
- Each worker sends its partial aggregation result to the worker where the Aggregator owner resides.
- The worker where the Aggregator owner resides executes `merge` multiple times to perform global aggregation.
- The worker where the Aggregator owner resides executes `terminate` to process the global aggregation result and determine whether to end the iteration.
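To make the call order concrete, the following single-process Java sketch simulates this workflow for a simple aggregator that computes the global mean of all vertex values. `MiniAggregator`, `SumCount`, `MeanAggregator`, and `run` are hypothetical names invented for this illustration and are not part of the MaxCompute Graph API. The real framework runs workers in parallel, passes `WorkerContext` objects, moves partial results over the network, and calls `createStartupValue` on every worker rather than once.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical single-process simulation of the Aggregator call order.
 * This is NOT the MaxCompute Graph API; each inner array of `workers`
 * stands for the vertex values held by one worker.
 */
public class AggregatorFlowDemo {

    /** Mirrors the five Aggregator operations; names and signatures are illustrative. */
    interface MiniAggregator<V> {
        V createStartupValue();                    // before superstep 0
        V createInitialValue(V lastAggregated);    // every worker, start of each superstep
        void aggregate(V value, double item);      // partial aggregation on one worker
        void merge(V value, V partial);            // owner worker only
        boolean terminate(int superstep, V value); // owner worker only; true ends the job
    }

    /** Aggregated state: sum and count of vertex values, plus the final mean. */
    static final class SumCount {
        double sum;
        long count;
        double mean;
    }

    /** Computes the global mean of all vertex values in one superstep. */
    static final class MeanAggregator implements MiniAggregator<SumCount> {
        public SumCount createStartupValue() { return new SumCount(); }
        public SumCount createInitialValue(SumCount last) { return new SumCount(); }
        public void aggregate(SumCount value, double item) {
            value.sum += item;
            value.count++;
        }
        public void merge(SumCount value, SumCount partial) {
            value.sum += partial.sum;
            value.count += partial.count;
        }
        public boolean terminate(int superstep, SumCount value) {
            value.mean = value.count == 0 ? 0.0 : value.sum / value.count; // final processing
            return true; // one superstep is enough for this aggregator
        }
    }

    /** Drives supersteps until terminate() returns true or maxSupersteps is reached. */
    static <V> V run(MiniAggregator<V> agg, double[][] workers, int maxSupersteps) {
        V last = agg.createStartupValue(); // called once here for simplicity
        for (int step = 0; step < maxSupersteps; step++) {
            List<V> partials = new ArrayList<>();
            for (double[] vertices : workers) {         // distributed part
                V value = agg.createInitialValue(last);
                for (double v : vertices) {
                    agg.aggregate(value, v);            // one context.aggregate(item) per vertex
                }
                partials.add(value);                    // sent to the owner worker
            }
            V global = partials.get(0);                 // owner part; assumes >= 1 worker
            for (int i = 1; i < partials.size(); i++) {
                agg.merge(global, partials.get(i));
            }
            boolean done = agg.terminate(step, global);
            last = global;                              // distributed for the next superstep
            if (done) {
                break;
            }
        }
        return last;
    }

    public static void main(String[] args) {
        double[][] workers = { {1, 2}, {3}, {4, 5, 6} };
        SumCount result = run(new MeanAggregator(), workers, 30);
        System.out.println(result.count + " values, mean " + result.mean); // 6 values, mean 3.5
    }
}
```

With this input, the three workers produce partial results (sum 3, count 2), (sum 3, count 1), and (sum 15, count 3). The owner merges them into (sum 21, count 6), and `terminate` turns that into a mean of 3.5 before ending the job.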
API operations
- `createStartupValue(context)`
This operation is performed once on all workers before any superstep starts. It is used to initialize `AggregatorValue`. In the first superstep (superstep 0), call `WorkerContext.getLastAggregatedValue()` or `ComputeContext.getLastAggregatedValue()` to obtain the initialized `AggregatorValue` object.
- `createInitialValue(context)`
This operation is performed once on all workers at the start of each superstep. It is used to initialize `AggregatorValue` for the current iteration. In most cases, `WorkerContext.getLastAggregatedValue()` is called to obtain the result of the previous iteration, and then only part of that value is reinitialized.
- `aggregate(value, item)`
This operation is also performed on all workers. Unlike the preceding two operations, which the framework calls automatically, it is triggered by an explicit call to `ComputeContext#aggregate(item)`. It is used to perform partial aggregation. The first parameter, `value`, is the aggregation result of the worker in the current superstep; its initial value is the object returned by `createInitialValue`. The second parameter, `item`, is the object that your code passes to `ComputeContext#aggregate(item)`. In most cases, `item` is used to update `value`. After all `aggregate` calls are complete, `value` holds the partial aggregation result of the worker, which the framework then sends to the worker where the Aggregator owner resides.
- `merge(value, partial)`
This operation is performed on the worker where the Aggregator owner resides. It merges the partial aggregation results of the workers into the global aggregation object. As in `aggregate`, `value` is the aggregated result and `partial` is the object to be aggregated; `partial` is used to update `value`.
For example, three workers w0, w1, and w2 generate partial aggregation results p0, p1, and p2. If p1, p0, and p2 arrive at the worker where the Aggregator owner resides in that order, `merge` is called as follows: `merge(p1, p0)` first merges p0 into p1, and then `merge(p1, p2)` merges p2 into p1. p1 is then the global aggregation result of this superstep.
If only one worker exists, there is nothing to merge, so `merge()` is not called.
- `terminate(context, value)`
After the worker where the Aggregator owner resides executes `merge()`, the framework calls `terminate(context, value)` for final processing. The second parameter, `value`, is the global aggregation result obtained by `merge()`, and it can be further modified in this operation. After `terminate()` is executed, the framework distributes the global aggregation object to all workers for the next superstep. If `terminate()` returns true, the iteration of the entire job ends; otherwise, the iteration continues. Returning true as soon as convergence is reached ends the job immediately, which is useful in machine learning scenarios.
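The following Java sketch replays the p1, p0, p2 arrival sequence from the `merge` description: the first partial result to arrive serves as the accumulator, and each later arrival is merged into it. `Partial`, `foldInArrivalOrder`, and the logging parameter are hypothetical names for this illustration only. Because arrival order is not fixed, a `merge` whose result does not depend on that order (for example, one based on addition, which is commutative and associative) yields the same global result in every run; this is an inference from the description above, not a documented requirement.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical demo of how the owner worker folds partial results with merge(). */
public class MergeOrderDemo {

    /** A partial aggregation result (illustrative): a sum and a count. */
    static final class Partial {
        final String name;
        double sum;
        long count;

        Partial(String name, double sum, long count) {
            this.name = name;
            this.sum = sum;
            this.count = count;
        }
    }

    /** Same contract as merge(value, partial): fold partial into value, and log the call. */
    static void merge(Partial value, Partial partial, StringBuilder log) {
        log.append("merge(").append(value.name).append(", ").append(partial.name).append(") ");
        value.sum += partial.sum;
        value.count += partial.count;
    }

    /** The first partial to arrive acts as the accumulator; later arrivals are merged into it. */
    static Partial foldInArrivalOrder(List<Partial> arrivals, StringBuilder log) {
        Partial value = arrivals.get(0);
        for (int i = 1; i < arrivals.size(); i++) {
            merge(value, arrivals.get(i), log);
        }
        return value;
    }

    public static void main(String[] args) {
        StringBuilder log = new StringBuilder();
        Partial global = foldInArrivalOrder(Arrays.asList(
                new Partial("p1", 5, 2), new Partial("p0", 3, 1), new Partial("p2", 7, 3)), log);
        System.out.println(log.toString().trim());             // merge(p1, p0) merge(p1, p2)
        System.out.println(global.sum + " / " + global.count); // 15.0 / 6

        StringBuilder single = new StringBuilder();
        foldInArrivalOrder(Arrays.asList(new Partial("p0", 3, 1)), single);
        System.out.println(single.length() == 0 ? "single worker: merge() not called" : single.toString());
    }
}
```

With a single worker the list holds one element, so the loop body never runs, which matches the statement that `merge()` is not called in that case.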
K-means clustering example
- GraphLoader
GraphLoader is used to load an input table and convert it to the vertices or edges of a graph. In this example, each row of the input table is a sample, each sample becomes one vertex, and the vertex value stores the sample.
A `Writable` class, `KmeansValue`, is defined as the value type of a vertex:
```java
public static class KmeansValue implements Writable {
  DenseVector sample;

  public KmeansValue() {
  }

  public KmeansValue(DenseVector v) {
    this.sample = v;
  }

  @Override
  public void write(DataOutput out) throws IOException {
    wirteForDenseVector(out, sample);
  }

  @Override
  public void readFields(DataInput in) throws IOException {
    sample = readFieldsForDenseVector(in);
  }
}
```
`KmeansValue` wraps a `DenseVector` object that stores one sample. The `DenseVector` type comes from matrix-toolkits-java (MTJ). `wirteForDenseVector()` and `readFieldsForDenseVector()` are used for serialization and deserialization.
Custom `KmeansReader` code:
```java
public static class KmeansReader extends
    GraphLoader<LongWritable, KmeansValue, NullWritable, NullWritable> {
  @Override
  public void load(
      LongWritable recordNum,
      WritableRecord record,
      MutationContext<LongWritable, KmeansValue, NullWritable, NullWritable> context)
      throws IOException {
    KmeansVertex v = new KmeansVertex();
    v.setId(recordNum);
    int n = record.size();
    DenseVector dv = new DenseVector(n);
    for (int i = 0; i < n; i++) {
      dv.set(i, ((DoubleWritable) record.get(i)).get());
    }
    v.setValue(new KmeansValue(dv));
    context.addVertexRequest(v);
  }
}
```
In `KmeansReader`, a vertex is created each time a row of data (a record) is read. `recordNum` is used as the vertex ID, and the `record` content is converted to a `DenseVector` object and wrapped in a `KmeansValue`, which becomes the vertex value.
- Vertex
Custom `KmeansVertex` code:
```java
public static class KmeansVertex extends
    Vertex<LongWritable, KmeansValue, NullWritable, NullWritable> {
  @Override
  public void compute(
      ComputeContext<LongWritable, KmeansValue, NullWritable, NullWritable> context,
      Iterable<NullWritable> messages) throws IOException {
    context.aggregate(getValue());
  }
}
```
In each iteration, every vertex passes its sample to `context.aggregate()`, so the samples maintained by each worker are partially aggregated. For details, see the Aggregator implementation in the following section.
- Aggregator
The main logic of k-means is concentrated in Aggregator. The custom `KmeansAggrValue` class maintains the content to be aggregated and distributed:
```java
public static class KmeansAggrValue implements Writable {
  DenseMatrix centroids;
  DenseMatrix sums;   // used to recalculate new centroids
  DenseVector counts; // used to recalculate new centroids

  @Override
  public void write(DataOutput out) throws IOException {
    wirteForDenseDenseMatrix(out, centroids);
    wirteForDenseDenseMatrix(out, sums);
    wirteForDenseVector(out, counts);
  }

  @Override
  public void readFields(DataInput in) throws IOException {
    centroids = readFieldsForDenseMatrix(in);
    sums = readFieldsForDenseMatrix(in);
    counts = readFieldsForDenseVector(in);
  }
}
```
In the preceding code, `KmeansAggrValue` maintains three objects:
`centroids`: the current K centers. If the samples are m-dimensional, `centroids` is a K × m matrix.
`sums`: a matrix of the same size as `centroids`. Each element records the sum of one dimension over the samples closest to one center. For example, `sums(i,j)` is the sum of dimension j over all samples closest to center i.
`counts`: a K-dimensional vector that records the number of samples closest to each center.
`sums` and `counts` are used together to calculate the new centers, and they are the main content to be aggregated.
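Written out, the new center derived from `sums` and `counts` is the per-dimension mean of the samples assigned to it. Below, c'_{ij} denotes dimension j of the new center i. The formula only covers clusters with at least one sample; how an empty cluster (counts(i) = 0) is handled is not specified in this topic.

```latex
c'_{ij} = \frac{\mathrm{sums}(i,j)}{\mathrm{counts}(i)}, \qquad \mathrm{counts}(i) > 0

% Worked example (m = 2): samples (1, 2) and (3, 4) are the ones closest to center i
\mathrm{sums}(i,\cdot) = (1+3,\; 2+4) = (4,\; 6), \qquad \mathrm{counts}(i) = 2
\;\Rightarrow\; c'_{i} = (4/2,\; 6/2) = (2,\; 3)
```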
`KmeansAggregator` is a custom Aggregator implementation class. The following describes how it implements each of the preceding API operations:
- Implementation of `createStartupValue()`
```java
public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
  @Override
  public KmeansAggrValue createStartupValue(WorkerContext context) throws IOException {
    KmeansAggrValue av = new KmeansAggrValue();
    byte[] centers = context.readCacheFile("centers");
    String[] lines = new String(centers).split("\n");

    int rows = lines.length;
    int cols = lines[0].split(",").length; // assumption rows >= 1

    av.centroids = new DenseMatrix(rows, cols);
    av.sums = new DenseMatrix(rows, cols);
    av.sums.zero();
    av.counts = new DenseVector(rows);
    av.counts.zero();

    for (int i = 0; i < lines.length; i++) {
      String[] ss = lines[i].split(",");
      for (int j = 0; j < ss.length; j++) {
        av.centroids.set(i, j, Double.valueOf(ss[j]));
      }
    }
    return av;
  }

  // The methods shown below are also members of this class.
}
```
This method creates a `KmeansAggrValue` object, reads the initial centers from the `centers` file (one center per line, with comma-separated coordinates), and assigns them to `centroids`. The initial values of `sums` and `counts` are 0.
- Implementation of `createInitialValue()`
```java
@Override
public KmeansAggrValue createInitialValue(WorkerContext context) throws IOException {
  KmeansAggrValue av = (KmeansAggrValue) context.getLastAggregatedValue(0);

  // reset for next iteration
  av.sums.zero();
  av.counts.zero();

  return av;
}
```
This method first obtains the `KmeansAggrValue` of the previous iteration and then clears `sums` and `counts`, so only the `centroids` of the previous iteration are retained.
- Implementation of `aggregate()`
```java
@Override
public void aggregate(KmeansAggrValue value, Object item) throws IOException {
  DenseVector sample = ((KmeansValue) item).sample;

  // find the nearest centroid
  int min = findNearestCentroid(value.centroids, sample);

  // update sum and count
  for (int i = 0; i < sample.size(); i++) {
    value.sums.add(min, i, sample.get(i));
  }
  value.counts.add(min, 1.0d);
}
```
This method calls `findNearestCentroid()` to find the index of the center closest to the sample `item`, adds each dimension of the sample to the corresponding row of `sums`, and increments the corresponding element of `counts` by 1.
The preceding three methods run on all workers and perform partial aggregation. The following methods run on the worker where the Aggregator owner resides and perform global aggregation:
- Implementation of `merge()`
```java
@Override
public void merge(KmeansAggrValue value, KmeansAggrValue partial) throws IOException {
  value.sums.add(partial.sums);
  value.counts.add(partial.counts);
}
```
In the preceding example, `merge` adds up the `sums` and `counts` aggregated by each worker.
- Implementation of `terminate()`
```java
@Override
public boolean terminate(WorkerContext context, KmeansAggrValue value) throws IOException {
  // Calculate the new centroids as the means of the assigned samples (sums / counts)
  DenseMatrix newCentroids = calculateNewCentroids(value.sums, value.counts, value.centroids);

  // print old centroids and new centroids for debugging
  System.out.println("\nsuperstep: " + context.getSuperstep()
      + "\nold centroids:\n" + value.centroids + " new centroids:\n" + newCentroids);
  boolean converged = isConverged(newCentroids, value.centroids, 0.05d);
  System.out.println("superstep: " + context.getSuperstep() + "/"
      + (context.getMaxIteration() - 1) + " converged: " + converged);

  if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
    // converged or reached max iteration: output centroids
    for (int i = 0; i < newCentroids.numRows(); i++) {
      Writable[] centroid = new Writable[newCentroids.numColumns()];
      for (int j = 0; j < newCentroids.numColumns(); j++) {
        centroid[j] = new DoubleWritable(newCentroids.get(i, j));
      }
      context.write(centroid);
    }
    // true means to terminate iteration
    return true;
  }

  // update centroids
  value.centroids.set(newCentroids);
  // false means to continue iteration
  return false;
}
```
In the preceding example, `terminate()` calls `calculateNewCentroids()` to compute the averages from `sums` and `counts`, which become the new centers. It then calls `isConverged()` to check whether the centers have converged, based on the Euclidean distance between the new and old centers. If the centers have converged or the maximum number of iterations is reached, the new centers are written to the output and true is returned to end the iteration. Otherwise, `centroids` is updated and false is returned to continue the iteration.
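The example calls `findNearestCentroid()`, `calculateNewCentroids()`, and `isConverged()` without showing them. The following plain-Java sketch shows one possible implementation over `double[][]` arrays instead of the MTJ `DenseMatrix`/`DenseVector` types. The signatures, the rule that an empty cluster keeps its old center, and the convergence rule (every center moves less than the threshold, measured by Euclidean distance) are assumptions, not the sample's actual code. The hypothetical `iterate` method compresses one superstep: its assignment loop plays the role of `aggregate()` plus `merge()`, and the recomputation is what `terminate()` does.

```java
import java.util.Arrays;

/** Hypothetical plain-array versions of the k-means helpers (assumed signatures). */
public class KmeansHelpers {

    /** Index of the centroid with the smallest Euclidean distance to the sample. */
    static int findNearestCentroid(double[][] centroids, double[] sample) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0.0;
            for (int j = 0; j < sample.length; j++) {
                double diff = centroids[i][j] - sample[j];
                d += diff * diff; // squared distance has the same minimizer
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }

    /** New center i = row i of sums / counts[i]; an empty cluster keeps its old center (assumption). */
    static double[][] calculateNewCentroids(double[][] sums, double[] counts, double[][] old) {
        double[][] next = new double[old.length][];
        for (int i = 0; i < old.length; i++) {
            next[i] = new double[old[i].length];
            for (int j = 0; j < old[i].length; j++) {
                next[i][j] = counts[i] > 0 ? sums[i][j] / counts[i] : old[i][j];
            }
        }
        return next;
    }

    /** Converged when every center moved less than `threshold` (Euclidean distance). */
    static boolean isConverged(double[][] next, double[][] old, double threshold) {
        for (int i = 0; i < old.length; i++) {
            double d = 0.0;
            for (int j = 0; j < old[i].length; j++) {
                double diff = next[i][j] - old[i][j];
                d += diff * diff;
            }
            if (Math.sqrt(d) >= threshold) {
                return false;
            }
        }
        return true;
    }

    /** One k-means round: accumulate sums/counts per nearest center, then recompute centers. */
    static double[][] iterate(double[][] centroids, double[][] samples) {
        int k = centroids.length;
        int m = centroids[0].length;
        double[][] sums = new double[k][m];
        double[] counts = new double[k];
        for (double[] s : samples) {
            int nearest = findNearestCentroid(centroids, s);
            for (int j = 0; j < m; j++) {
                sums[nearest][j] += s[j];
            }
            counts[nearest] += 1.0;
        }
        return calculateNewCentroids(sums, counts, centroids);
    }

    public static void main(String[] args) {
        double[][] centroids = { {0, 0}, {10, 10} };
        double[][] samples = { {1, 2}, {3, 4}, {9, 9}, {11, 11} };
        for (int step = 0; step < 30; step++) { // 30 mirrors setMaxIteration(30)
            double[][] next = iterate(centroids, samples);
            boolean converged = isConverged(next, centroids, 0.05);
            System.out.println("superstep " + step + ": " + Arrays.deepToString(next)
                    + " converged: " + converged);
            if (converged) {
                break;
            }
            centroids = next;
        }
    }
}
```

On these four toy samples, the first round moves center 0 from (0, 0) to (2, 3) and reports no convergence; the second round reproduces the same centers, so the loop stops, much as `terminate()` would return true.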
- main method
The `main` method constructs a `GraphJob`, configures the related settings, and submits the job.
```java
public static void main(String[] args) throws IOException {
  if (args.length < 2) {
    printUsage();
  }

  GraphJob job = new GraphJob();

  job.setGraphLoaderClass(KmeansReader.class);
  job.setRuntimePartitioning(false);
  job.setVertexClass(KmeansVertex.class);
  job.setAggregatorClass(KmeansAggregator.class);
  job.addInput(TableInfo.builder().tableName(args[0]).build());
  job.addOutput(TableInfo.builder().tableName(args[1]).build());

  // default max iteration is 30
  job.setMaxIteration(30);
  if (args.length >= 3) {
    job.setMaxIteration(Integer.parseInt(args[2]));
  }

  long start = System.currentTimeMillis();
  job.run();
  System.out.println("Job Finished in "
      + (System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
```
Note: If `job.setRuntimePartitioning` is set to false, the data loaded by each worker is not partitioned by the partitioner. Each piece of data is maintained by the same worker that loads it.