import java.io.*;
import java.net.URI;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapred.*;
import org.apache.hadoop.util.GenericOptionsParser;
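// Iterative k-means clustering of 2-D integer points on Hadoop (old
// "mapred" API). Each MapReduce pass assigns every point to its nearest
// center and writes the averaged centers to a SequenceFile, which is
// shipped to the next pass through the DistributedCache.
//
// A hypothetical invocation, assuming the class is packaged in
// kmeans.jar and points.txt holds one tab-separated "x<TAB>y" pair per
// line (the jar and file names are illustrative, not from the source):
//
//   hadoop jar kmeans.jar KMeans out 5 points.txt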
public class KMeans {
    static enum Counter { CENTERS, CHANGE, ITERATIONS }
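    // A 2-D point that doubles as a running sum: num counts how many
    // points have been folded in via add(), and norm() divides through to
    // recover the mean. For example, adding (2,2,num=1) and (4,6,num=1)
    // gives (6,8,num=2), and norm() then yields (3,4,num=1). Implements
    // WritableComparable so it can serve as a MapReduce key.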
    public static class Point implements WritableComparable<Point> {
        // Longs because these will store the sum of many ints.
        public LongWritable x;
        public LongWritable y;
        public IntWritable num;

        // For summation points.
        public Point() {
            this.x = new LongWritable(0);
            this.y = new LongWritable(0);
            this.num = new IntWritable(0);
        }

        public Point(int x, int y) {
            this.x = new LongWritable(x);
            this.y = new LongWritable(y);
            this.num = new IntWritable(1);
        }

        public Point(IntWritable x, IntWritable y) {
            this.x = new LongWritable(x.get());
            this.y = new LongWritable(y.get());
            this.num = new IntWritable(1);
        }

        public void add(Point that) {
            x.set(x.get() + that.x.get());
            y.set(y.get() + that.y.get());
            num.set(num.get() + that.num.get());
        }

        // Replace the sum with the mean of the accumulated points.
        public void norm() {
            x.set(x.get() / num.get());
            y.set(y.get() / num.get());
            num.set(1);
        }

        public void write(DataOutput out) throws IOException {
            x.write(out);
            y.write(out);
            num.write(out);
        }

        public void readFields(DataInput in) throws IOException {
            x.readFields(in);
            y.readFields(in);
            num.readFields(in);
        }

        // Squared Euclidean distance; enough for nearest-center comparisons.
        public long distance(Point that) {
            long dx = that.x.get() - x.get();
            long dy = that.y.get() - y.get();
            return dx * dx + dy * dy;
        }

        public String toString() {
            String ret = x.toString() + '\t' + y.toString();
            if (num.get() != 1)
                ret += '\t' + num.toString();
            return ret;
        }

        public int compareTo(Point that) {
            int ret = x.compareTo(that.x);
            if (ret == 0)
                ret = y.compareTo(that.y);
            if (ret == 0)
                ret = num.compareTo(that.num);
            return ret;
        }

        // Map output keys are partitioned by hashCode, so equal points
        // must hash equally in every mapper JVM; Object's default does not.
        public int hashCode() {
            return 31 * (int) x.get() + (int) y.get();
        }

        public boolean equals(Object o) {
            return o instanceof Point && compareTo((Point) o) == 0;
        }
    }
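    // Mapper: configure() loads the current centers from the
    // DistributedCache; map() then emits each input point keyed by its
    // nearest center. If the cache file holds fewer than k centers, the
    // remainder are drawn uniformly at random from [0, 2^20).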
    public static class Map extends MapReduceBase
            implements Mapper<Text, Text, Point, Point> {
        private Vector<Point> centers;
        private IOException error;

        public void configure(JobConf conf) {
            try {
                Path paths[] = DistributedCache.getLocalCacheFiles(conf);
                if (paths == null || paths.length != 1)
                    throw new IOException("Need exactly 1 centers file");
                FileSystem fs = FileSystem.getLocal(conf);
                SequenceFile.Reader in =
                    new SequenceFile.Reader(fs, paths[0], conf);
                centers = new Vector<Point>();
                IntWritable x = new IntWritable();
                IntWritable y = new IntWritable();
                while (in.next(x, y))
                    centers.add(new Point(x, y));
                in.close();
                // Generate new random centers if we don't have enough.
                int k = conf.getInt("k", 0);
                Random rand = new Random();
                final int MAX = 1024 * 1024;
                for (int i = centers.size(); i < k; i++) {
                    x.set(rand.nextInt(MAX));
                    y.set(rand.nextInt(MAX));
                    centers.add(new Point(x, y));
                }
            }
            catch (IOException e) {
                // configure() cannot throw, so stash the error for map().
                error = e;
            }
        }

        public void map(Text xt, Text yt,
                OutputCollector<Point, Point> output, Reporter reporter)
                throws IOException {
            if (error != null)
                throw error;
            int x = Integer.valueOf(xt.toString());
            int y = Integer.valueOf(yt.toString());
            Point p = new Point(x, y);
            // Emit the point keyed by the nearest center.
            Point center = null;
            long distance = Long.MAX_VALUE;
            for (Point c : centers) {
                long d = c.distance(p);
                if (d <= distance) {
                    distance = d;
                    center = c;
                }
            }
            output.collect(center, p);
        }
    }
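    // Combiner: collapses the points assigned to a center into a single
    // weighted-sum Point before the shuffle, cutting the data sent to
    // the reducers.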
    public static class Combine extends MapReduceBase
            implements Reducer<Point, Point, Point, Point> {
        public void reduce(Point center, Iterator<Point> points,
                OutputCollector<Point, Point> output, Reporter reporter)
                throws IOException {
            Point sum = new Point();
            while (points.hasNext()) {
                sum.add(points.next());
            }
            output.collect(center, sum);
        }
    }
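    // Reducer: averages the weighted sums into a new center, emits it as
    // an (x, y) IntWritable pair, and accumulates the squared distance
    // each center moved in the CHANGE counter, which the driver reads as
    // its convergence measure.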
    public static class Reduce extends MapReduceBase
            implements Reducer<Point, Point, IntWritable, IntWritable> {
        public void reduce(Point center, Iterator<Point> points,
                OutputCollector<IntWritable, IntWritable> output,
                Reporter reporter)
                throws IOException {
            Point sum = new Point();
            while (points.hasNext()) {
                sum.add(points.next());
            }
            sum.norm();
            IntWritable x = new IntWritable((int) sum.x.get());
            IntWritable y = new IntWritable((int) sum.y.get());
            output.collect(x, y);
            reporter.incrCounter(Counter.CHANGE, sum.distance(center));
            reporter.incrCounter(Counter.CENTERS, 1);
        }
    }
    public static void error(String msg) {
        System.err.println(msg);
        System.exit(1);
    }
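    // Seeds the first k centers from the first k lines of the text input,
    // writing them as a SequenceFile of IntWritable pairs, the same
    // format the reducer produces, so every iteration reads centers the
    // same way.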
    public static void initialCenters(int k, JobConf conf, FileSystem fs,
            Path in, Path out)
            throws IOException {
        BufferedReader input =
            new BufferedReader(new InputStreamReader(fs.open(in)));
        SequenceFile.Writer output = new SequenceFile.Writer(
            fs, conf, out, IntWritable.class, IntWritable.class);
        IntWritable x = new IntWritable();
        IntWritable y = new IntWritable();
        for (int i = 0; i < k; i++) {
            String line = input.readLine();
            if (line == null)
                error("Not enough points for number of means");
            String parts[] = line.split("\t");
            if (parts.length != 2)
                throw new IOException("Found a point without two parts");
            x.set(Integer.valueOf(parts[0]));
            y.set(Integer.valueOf(parts[1]));
            output.append(x, y);
        }
        output.close();
        input.close();
    }
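    // Driver: builds one JobConf and reruns it, feeding each iteration's
    // part-00000 output back in as the next centers file, until the
    // centers' total squared movement is at most 100 * k or 1000
    // iterations have run.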
    public static void main(String args[]) throws IOException {
        JobConf conf = new JobConf(KMeans.class);
        GenericOptionsParser opts = new GenericOptionsParser(conf, args);
        String paths[] = opts.getRemainingArgs();
        FileSystem fs = FileSystem.get(conf);
        if (paths.length < 3)
            error("Usage:\n"
                + "\tKMeans <file to display>\n"
                + "\tKMeans <output> <k> <input file>\n");
        Path outdir = new Path(paths[0]);
        int k = Integer.valueOf(paths[1]);
        Path firstin = new Path(paths[2]);
        if (k < 1 || k > 20)
            error("Strange number of means: " + paths[1]);
        if (fs.exists(outdir)) {
            if (!fs.getFileStatus(outdir).isDir())
                error("Output directory \"" + outdir.toString()
                    + "\" exists and is not a directory.");
        }
        else {
            fs.mkdirs(outdir);
        }

        // Input: text file, each line "x\ty"
        conf.setInputFormat(KeyValueTextInputFormat.class);
        for (int i = 2; i < paths.length; i++)
            FileInputFormat.addInputPath(conf, new Path(paths[i]));
        conf.setInt("k", k);
        // Map: (x,y) -> (centroid, point)
        conf.setMapperClass(Map.class);
        conf.setMapOutputKeyClass(Point.class);
        conf.setMapOutputValueClass(Point.class);
        // Combine: (centroid, points) -> (centroid, weighted point)
        conf.setCombinerClass(Combine.class);
        // Reduce: (centroid, weighted points) -> (x, y) new centroid
        conf.setReducerClass(Reduce.class);
        conf.setOutputKeyClass(IntWritable.class);
        conf.setOutputValueClass(IntWritable.class);
        // Output
        conf.setOutputFormat(SequenceFileOutputFormat.class);

        // Choose initial centers
        Path centers = new Path(outdir, "initial.seq");
        initialCenters(k, conf, fs, firstin, centers);

        // Iterate
        long change = Long.MAX_VALUE;
        URI cache[] = new URI[1];
        for (int iter = 1; iter <= 1000 && change > 100 * k; iter++) {
            Path jobdir = new Path(outdir, Integer.toString(iter));
            FileOutputFormat.setOutputPath(conf, jobdir);
            conf.setJobName("k-Means " + iter);
            conf.setJarByClass(KMeans.class);
            cache[0] = centers.toUri();
            DistributedCache.setCacheFiles(cache, conf);
            RunningJob result = JobClient.runJob(conf);
            System.out.println("Iteration: " + iter);
            change = result.getCounters().getCounter(Counter.CHANGE);
            // The new centers become the input to the next iteration.
            centers = new Path(jobdir, "part-00000");
        }
    }
}