001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math.stat.clustering;
019
020 import java.util.ArrayList;
021 import java.util.Collection;
022 import java.util.List;
023 import java.util.Random;
024
025 /**
026 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
027 * @param <T> type of the points to cluster
028 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
029 * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $
030 * @since 2.0
031 */
032 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
033
034 /** Random generator for choosing initial centers. */
035 private final Random random;
036
037 /** Build a clusterer.
038 * @param random random generator to use for choosing initial centers
039 */
040 public KMeansPlusPlusClusterer(final Random random) {
041 this.random = random;
042 }
043
044 /**
045 * Runs the K-means++ clustering algorithm.
046 *
047 * @param points the points to cluster
048 * @param k the number of clusters to split the data into
049 * @param maxIterations the maximum number of iterations to run the algorithm
050 * for. If negative, no maximum will be used
051 * @return a list of clusters containing the points
052 */
053 public List<Cluster<T>> cluster(final Collection<T> points,
054 final int k, final int maxIterations) {
055 // create the initial clusters
056 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
057 assignPointsToClusters(clusters, points);
058
059 // iterate through updating the centers until we're done
060 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
061 for (int count = 0; count < max; count++) {
062 boolean clusteringChanged = false;
063 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
064 for (final Cluster<T> cluster : clusters) {
065 final T newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
066 if (!newCenter.equals(cluster.getCenter())) {
067 clusteringChanged = true;
068 }
069 newClusters.add(new Cluster<T>(newCenter));
070 }
071 if (!clusteringChanged) {
072 return clusters;
073 }
074 assignPointsToClusters(newClusters, points);
075 clusters = newClusters;
076 }
077 return clusters;
078 }
079
080 /**
081 * Adds the given points to the closest {@link Cluster}.
082 *
083 * @param <T> type of the points to cluster
084 * @param clusters the {@link Cluster}s to add the points to
085 * @param points the points to add to the given {@link Cluster}s
086 */
087 private static <T extends Clusterable<T>> void
088 assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
089 for (final T p : points) {
090 Cluster<T> cluster = getNearestCluster(clusters, p);
091 cluster.addPoint(p);
092 }
093 }
094
095 /**
096 * Use K-means++ to choose the initial centers.
097 *
098 * @param <T> type of the points to cluster
099 * @param points the points to choose the initial centers from
100 * @param k the number of centers to choose
101 * @param random random generator to use
102 * @return the initial centers
103 */
104 private static <T extends Clusterable<T>> List<Cluster<T>>
105 chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
106
107 final List<T> pointSet = new ArrayList<T>(points);
108 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
109
110 // Choose one center uniformly at random from among the data points.
111 final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
112 resultSet.add(new Cluster<T>(firstPoint));
113
114 final double[] dx2 = new double[pointSet.size()];
115 while (resultSet.size() < k) {
116 // For each data point x, compute D(x), the distance between x and
117 // the nearest center that has already been chosen.
118 int sum = 0;
119 for (int i = 0; i < pointSet.size(); i++) {
120 final T p = pointSet.get(i);
121 final Cluster<T> nearest = getNearestCluster(resultSet, p);
122 final double d = p.distanceFrom(nearest.getCenter());
123 sum += d * d;
124 dx2[i] = sum;
125 }
126
127 // Add one new data point as a center. Each point x is chosen with
128 // probability proportional to D(x)2
129 final double r = random.nextDouble() * sum;
130 for (int i = 0 ; i < dx2.length; i++) {
131 if (dx2[i] >= r) {
132 final T p = pointSet.remove(i);
133 resultSet.add(new Cluster<T>(p));
134 break;
135 }
136 }
137 }
138
139 return resultSet;
140
141 }
142
143 /**
144 * Returns the nearest {@link Cluster} to the given point
145 *
146 * @param <T> type of the points to cluster
147 * @param clusters the {@link Cluster}s to search
148 * @param point the point to find the nearest {@link Cluster} for
149 * @return the nearest {@link Cluster} to the given point
150 */
151 private static <T extends Clusterable<T>> Cluster<T>
152 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
153 double minDistance = Double.MAX_VALUE;
154 Cluster<T> minCluster = null;
155 for (final Cluster<T> c : clusters) {
156 final double distance = point.distanceFrom(c.getCenter());
157 if (distance < minDistance) {
158 minDistance = distance;
159 minCluster = c;
160 }
161 }
162 return minCluster;
163 }
164
165 }