Spark范例:K-means算法
- - yiihsia[互联网后端技术]_yiihsia[互联网后端技术]k-means 算法接受输入量 k ;然后将n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小. 聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的. 算法首先会随机确定K个中心位置(位于空间中代表聚类中心的点),然后将各个数据项分配给最临近的中心点.
k-means 算法接受输入量 k ;然后将n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。
算法首先会随机确定K个中心位置(位于空间中代表聚类中心的点),然后将各个数据项分配给最临近的中心点。待分配完成之后,聚类中心就会移到分配给该聚类的所有节点的平均位置处,然后整个分配过程重新开始。这个过程会一直重复下去,直到分配过程不再产生变化为止。下图是涉及5个数据项和2个聚类的过程。
spark的例子里也提供了K-means算法的实现,我加了注释方便理解:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | package spark.examples import java.util.Random import spark.SparkContext import spark.SparkContext._ import spark.examples.Vector._ object SparkKMeans { /** * line -> vector */ def parseVector(line: String): Vector = { return new Vector(line.split(' ').map(_.toDouble)) } /** * 计算该节点的最近中心节点 */ def closestCenter(p: Vector, centers: Array[Vector]): Int = { var bestIndex = 0 var bestDist = p.squaredDist(centers(0))//差平方之和 for (i <- 1 until centers.length) { val dist = p.squaredDist(centers(i)) if (dist < bestDist) { bestDist = dist bestIndex = i } } return bestIndex } def main(args: Array[String]) { if (args.length < 3) { System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>") System.exit(1) } val sc = new SparkContext(args(0), "SparkKMeans") val lines = sc.textFile(args(1), args(5).toInt) val points = lines.map(parseVector(_)).cache() //文本中每行为一个节点,再将每个节点转换成Vector val dimensions = args(2).toInt//节点的维度 val k = args(3).toInt //聚类个数 val iterations = args(4).toInt//迭代次数 // 随机初始化k个中心节点 val rand = new Random(42) var centers = new Array[Vector](k) for (i <- 0 until k) centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1) println("Initial centers: " + centers.mkString(", ")) val time1 = System.currentTimeMillis() for (i <- 1 to iterations) { println("On iteration " + i) // Map each point to the index of its closest center and a (point, 1) pair // that we will use to compute an average later val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) } val newCenters = mappedPoints.reduceByKey { case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2) //(向量相加, 计数器相加) }.map { case (id, (sum, count)) => (id, sum / count)//根据前面的聚类,重新计算中心节点的位置 }.collect // 更新中心节点 for ((id, value) <- newCenters) { centers(id) = value } } val time2 = System.currentTimeMillis() println("Final centers: " + centers.mkString(", ") + ", time: "+(time2- time1) ) } } |
例子中使用了iterations来限制迭代次数,并不是一种好的方法。可以设置一个阀值,在更新中心节点前,判断新的节点和上一次计算的中心计算差平方之和是否已经到了阀值内,如果是,则不需要继续计算下去。
其中用到的Vector类 https://github.com/mesos/spark/blob/master/examples/src/main/scala/spark/examples/Vector.scala