K-means アルゴリズムは、典型的なクラスタリングアルゴリズムです。
これは、空間内の k 個の頂点を中心として使用し、それらに最も近い頂点をグループ化することによってクラスタリングを実行します。 クラスタリングの中心値は、最適なクラスタリング結果が得られるまで、反復しながら連続的に更新されます。
サンプルの集合を k個のクラスに分割するため、アルゴリズムは以下のように動作します。
- k 個のクラスの中心の初期値を選択します。
- 任意のサンプルから k 個の中心までの距離を反復 i で 計算し、サンプルを最も近い中心のクラスにグループ化します。
- 平均や他の方法を使ってクラスの中心値を更新します。
- すべての k 個のクラスター中心について、反復後に更新された値が変化しないままであるか、しきい値よりも小さい場合、反復は終了します。 そうでない場合、反復は継続します。
サンプルコード
K-means クラスタリングアルゴリズムのコードは以下のとおりです。
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.logging.log4j.Logger;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.graph.Aggregator;
Import com. aliyun. ODPS. graph. computercontext;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.NullWritable;
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.io.Text;
import com.aliyun.odps.io.Tuple;
import com.aliyun.odps.io.Writable;
public class Kmeans {
private final static Logger LOG = Logger.getLogger(Kmeans.class);
public static class KmeansVertex extends
Vertex<Text, Tuple, NullWritable, NullWritable> {
@ Override
public void compute(
ComputeContext<Text, Tuple, NullWritable, NullWritable> context,
Iterable<NullWritable> messages) throws IOException {
context.aggregate(getValue());
}
public static class KmeansVertexReader extends
GraphLoader<Text, Tuple, NullWritable, NullWritable> {
@Override
public void load(LongWritable recordNum, WritableRecord record,
MutationContext<Text, Tuple, NullWritable, NullWritable> context)
throws IOException {
KmeansVertex vertex = new KmeansVertex();
vertex.setId(new Text(String.valueOf(recordNum.get())));
vertex.setValue(new Tuple(record.getAll()));
context.addVertexRequest(vertex);
public static class KmeansAggrValue implements Writable {
Tuple centers = new Tuple();
Tuple sums = new Tuple();
Tuple counts = new Tuple();
@ Override
public void write(DataOutput out) throws IOException {
centers.write(out);
sums.write(out);
counts.write(out);
@Override
public void readFields(DataInput in) throws IOException {
centers = new Tuple();
centers.readFields(in);
sums = new Tuple();
sums.readFields(in);
counts = new Tuple();
counts.readFields(in);
@Override
public String toString(){
return "centers " + centers.toString() + ", sums " + sums.toString()
+ ", counts " + counts.toString();
public static class KmeansAggregator extends Aggregator<KmeansAggrValue> {
@SuppressWarnings("rawtypes")
@Override
public KmeansAggrValue createInitialValue(WorkerContext context)
throws IOException {
KmeansAggrValue aggrVal = null;
if (context.getSuperstep() == 0) {
aggrVal = new KmeansAggrValue();
aggrVal.centers = new Tuple();
aggrVal.sums = new Tuple();
aggrVal.counts = new Tuple();
byte[] centers = context.readCacheFile("centers");
String lines[] = new String(centers).split("\n");
for (int i = 0; i < lines.length; i++) {
String[] ss = lines[i].split(",");
Tuple center = new Tuple();
Tuple sum = new Tuple();
for (int j = 0; j < ss.length; ++j) {
center.append(new DoubleWritable(Double.valueOf(ss[j].trim())));
sum.append(new DoubleWritable(0.0));
LongWritable count = new LongWritable(0);
aggrVal.sums.append(sum);
aggrVal.counts.append(count);
aggrVal.centers.append(center);
} else{
aggrVal = (KmeansAggrValue) context.getLastAggregatedValue(0);
return aggrVal;
@Override
Public void aggregate (glasvalue, object item ){
int min = 0;
double mindist = Double.MAX_VALUE;
Tuple point = (Tuple) item;
for (int i = 0; i < value.centers.size(); i++) {
Tuple center = (Tuple) value.centers.get(i);
// use Euclidean Distance, no need to calculate sqrt
double dist = 0.0d;
for (int j = 0; j < center.size(); j++) {
double v = ((DoubleWritable) point.get(j)).get()
- ((DoubleWritable) center.get(j)).get();
dist += v * v;
if (dist < mindist) {
mindist = dist;
min = i;
// update sum and count
Tuple sum = (Tuple) value.sums.get(min);
for (int i = 0; i < point.size(); i++) {
DoubleWritable s = (DoubleWritable) sum.get(i);
s.set(s.get() + ((DoubleWritable) point.get(i)).get());
LongWritable count = (LongWritable) value.counts.get(min);
count.set(count.get() + 1);
@Override
public void merge(KmeansAggrValue value, KmeansAggrValue partial) {
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple that = (Tuple) partial.sums.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
s.set(s.get() + ((DoubleWritable) that.get(j)).get());
for (int i = 0; i < value.counts.size(); i++) {
LongWritable count = (LongWritable) value.counts.get(i);
count.set(count.get() + ((LongWritable) partial.counts.get(i)).get());
@SuppressWarnings("rawtypes")
@Override
public boolean terminate(WorkerContext context, KmeansAggrValue value)
throws IOException {
// compute new centers
Tuple newCenters = new Tuple(value.sums.size());
for (int i = 0; i < value.sums.size(); i++) {
Tuple sum = (Tuple) value.sums.get(i);
Tuple newCenter = new Tuple(sum.size());
LongWritable c = (LongWritable) value.counts.get(i);
for (int j = 0; j < sum.size(); j++) {
DoubleWritable s = (DoubleWritable) sum.get(j);
double val = s.get() / c.get();
newCenter.set(j, new DoubleWritable(val));
// reset sum for next iteration
s.set(0.0d);
// reset count for next iteration
c.set(0);
newCenters.set(i, newCenter);
// update centers
Tuple oldCenters = value.centers;
value.centers = newCenters;
LOG.info("old centers: " + oldCenters + ", new centers: " + newCenters);
// compare new/old centers
boolean converged = true;
for (int i = 0; i < value.centers.size() && converged; i++) {
Tuple oldCenter = (Tuple) oldCenters.get(i);
Tuple newCenter = (Tuple) newCenters.get(i);
double sum = 0.0d;
for (int j = 0; j < newCenter.size(); j++) {
double v = ((DoubleWritable) newCenter.get(j)).get()
- ((DoubleWritable) oldCenter.get(j)).get();
sum += v * v;
double dist = Math.sqrt(sum);
LOG.info("old center: " + oldCenter + ", new center: " + newCenter
+ ", dist: " + dist);
// converge threshold for each center: 0.05
converged = dist < 0.05d;
if (converged || context.getSuperstep() == context.getMaxIteration() - 1) {
// converged or reach max iteration, output centers
for (int i = 0; i < value.centers.size(); i++) {
context.write(((Tuple) value.centers.get(i)).toArray());
// true means to terminate iteration
return true;
// false means to continue iteration
return false;
private static void printUsage() {
System. out. println ("Usage: <in> <out> [Max iterations (default 30)] ");
System.exit(-1);
public static void main(final String [] args)throws IOException{
if (args.length < 2)
printUsage();
GraphJob job = new GraphJob();
job.setGraphLoaderClass(KmeansVertexReader.class);
job.setRuntimePartitioning(false);
job.setVertexClass(KmeansVertex.class);
job.setAggregatorClass(KmeansAggregator.class);
job.addInput(TableInfo.builder().tableName(args[0]).build());
job.addOutput(TableInfo.builder().tableName(args[1]).build());
// default max iteration is 30
job.setMaxIteration(30);
if (args.length >= 3)
job.setMaxIteration(Integer.parseInt(args[2]));
long start = System.currentTimeMillis();
job.run();
System.out.println("Job Finished in "
+ (System.currentTimeMillis() - start) / 1000.0 + " seconds");
以下は、K-means のソースコードについての説明です。
- 26 行目: KmeansVertex を定義します。 compute () メソッドの実装はシンプルです。 コンテキストオブジェクトの aggregate () メソッドを呼び出します。 次に、現在の頂点の値 (タプル 型で、ベクトルで表現される) を送信します。
- 38 行目: KmeansVertexReader クラスを定義し、グラフを読み込み、テーブル内の各レコードを頂点と見なします。 頂点 ID は関係ありません。送信された recordNum は ID として使用されます。 頂点の値は、レコードのすべての列から構成されるタプルです。
- 83 行目: KmeansAggregator を定義します。 このクラスは、K-means アルゴリズムの主なロジックをカプセル化します。ここで、
- createInitialValue は、反復ごとに初期値を作成します (k クラス 中心点)。 最初の反復 (superstep が 0) では、値は中心点の初期値です。 それ以外の場合は、値は最後の反復が終了したときの新しい中心点です。
- aggregate () メソッドは、各頂点から異なるクラスの中心までの距離を計算し、最も近い中心のクラスとしてその頂点を分類し、そのクラスの合計とカウントを更新します。
- merge () メソッドは、各 Worker によって収集された合計とカウントを組み合わせます。
- terminate () メソッドは、各クラスの合計とカウントに基づいて新しい中心点を計算します。 新しい中心点と古い中心点の間の距離がしきい値より小さいか、または反復回数が上限値に達すると、反復は終了します (false が返されます)。 最終的な中心点が結果テーブルに書き込まれます。
- 236 行目: メインプログラム (main 関数) を実行し、GraphJob を定義し、そして Vertex/GraphLoader/Aggregator の実装を指定します。 最大反復数 (デフォルトは 30)、および入力テーブルと出力テーブル。
- 243 行目: job.setRuntimePartitioning (false) を指定します。 K-means アルゴリズムでは、グラフの読み込み中に頂点を分散させる必要はありません。 RuntimePartitioning が false に設定されると、グラフ読み込みのパフォーマンスが向上します。