| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.commons.math4.legacy.ml.clustering; |
| |
| import org.apache.commons.math4.legacy.exception.NullArgumentException; |
| import org.apache.commons.math4.legacy.exception.ConvergenceException; |
| import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; |
| import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; |
| import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure; |
| import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance; |
| import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance; |
| import org.apache.commons.rng.UniformRandomProvider; |
| import org.apache.commons.rng.simple.RandomSource; |
| |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.List; |
| |
| /** |
| * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. |
| * @param <T> type of the points to cluster |
| * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a> |
| * @since 3.2 |
| */ |
| public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> { |
| |
| /** Strategies to use for replacing an empty cluster. */ |
| public enum EmptyClusterStrategy { |
| |
| /** Split the cluster with largest distance variance. */ |
| LARGEST_VARIANCE, |
| |
| /** Split the cluster with largest number of points. */ |
| LARGEST_POINTS_NUMBER, |
| |
| /** Create a cluster around the point farthest from its centroid. */ |
| FARTHEST_POINT, |
| |
| /** Generate an error. */ |
| ERROR |
| |
| } |
| |
| /** The number of clusters. */ |
| private final int numberOfClusters; |
| |
| /** The maximum number of iterations. */ |
| private final int maxIterations; |
| |
| /** Random generator for choosing initial centers. */ |
| private final UniformRandomProvider random; |
| |
| /** Selected strategy for empty clusters. */ |
| private final EmptyClusterStrategy emptyStrategy; |
| |
| /** Build a clusterer. |
| * <p> |
| * The default strategy for handling empty clusters that may appear during |
| * algorithm iterations is to split the cluster with largest distance variance. |
| * <p> |
| * The euclidean distance will be used as default distance measure. |
| * |
| * @param k the number of clusters to split the data into |
| */ |
| public KMeansPlusPlusClusterer(final int k) { |
| this(k, -1); |
| } |
| |
| /** Build a clusterer. |
| * <p> |
| * The default strategy for handling empty clusters that may appear during |
| * algorithm iterations is to split the cluster with largest distance variance. |
| * <p> |
| * The euclidean distance will be used as default distance measure. |
| * |
| * @param k the number of clusters to split the data into |
| * @param maxIterations the maximum number of iterations to run the algorithm for. |
| * If negative, no maximum will be used. |
| */ |
| public KMeansPlusPlusClusterer(final int k, final int maxIterations) { |
| this(k, maxIterations, new EuclideanDistance()); |
| } |
| |
| /** Build a clusterer. |
| * <p> |
| * The default strategy for handling empty clusters that may appear during |
| * algorithm iterations is to split the cluster with largest distance variance. |
| * |
| * @param k the number of clusters to split the data into |
| * @param maxIterations the maximum number of iterations to run the algorithm for. |
| * If negative, no maximum will be used. |
| * @param measure the distance measure to use |
| */ |
| public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) { |
| this(k, maxIterations, measure, RandomSource.MT_64.create()); |
| } |
| |
| /** Build a clusterer. |
| * <p> |
| * The default strategy for handling empty clusters that may appear during |
| * algorithm iterations is to split the cluster with largest distance variance. |
| * |
| * @param k the number of clusters to split the data into |
| * @param maxIterations the maximum number of iterations to run the algorithm for. |
| * If negative, no maximum will be used. |
| * @param measure the distance measure to use |
| * @param random random generator to use for choosing initial centers |
| */ |
| public KMeansPlusPlusClusterer(final int k, final int maxIterations, |
| final DistanceMeasure measure, |
| final UniformRandomProvider random) { |
| this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE); |
| } |
| |
| /** Build a clusterer. |
| * |
| * @param k the number of clusters to split the data into |
| * @param maxIterations the maximum number of iterations to run the algorithm for. |
| * If negative, no maximum will be used. |
| * @param measure the distance measure to use |
| * @param random random generator to use for choosing initial centers |
| * @param emptyStrategy strategy to use for handling empty clusters that |
| * may appear during algorithm iterations |
| */ |
| public KMeansPlusPlusClusterer(final int k, final int maxIterations, |
| final DistanceMeasure measure, |
| final UniformRandomProvider random, |
| final EmptyClusterStrategy emptyStrategy) { |
| super(measure); |
| this.numberOfClusters = k; |
| this.maxIterations = maxIterations; |
| this.random = random; |
| this.emptyStrategy = emptyStrategy; |
| } |
| |
| /** |
| * Return the number of clusters this instance will use. |
| * @return the number of clusters |
| */ |
| public int getNumberOfClusters() { |
| return numberOfClusters; |
| } |
| |
| /** |
| * Returns the maximum number of iterations this instance will use. |
| * @return the maximum number of iterations, or -1 if no maximum is set |
| */ |
| public int getMaxIterations() { |
| return maxIterations; |
| } |
| |
| /** |
| * Runs the K-means++ clustering algorithm. |
| * |
| * @param points the points to cluster |
| * @return a list of clusters containing the points |
| * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException |
| * if the data points are null or the number of clusters is larger than the |
| * number of data points |
| * @throws ConvergenceException if an empty cluster is encountered and the |
| * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR} |
| */ |
| @Override |
| public List<CentroidCluster<T>> cluster(final Collection<T> points) { |
| // sanity checks |
| NullArgumentException.check(points); |
| |
| // number of clusters has to be smaller or equal the number of data points |
| if (points.size() < numberOfClusters) { |
| throw new NumberIsTooSmallException(points.size(), numberOfClusters, false); |
| } |
| |
| // create the initial clusters |
| List<CentroidCluster<T>> clusters = chooseInitialCenters(points); |
| |
| // create an array containing the latest assignment of a point to a cluster |
| // no need to initialize the array, as it will be filled with the first assignment |
| int[] assignments = new int[points.size()]; |
| assignPointsToClusters(clusters, points, assignments); |
| |
| // iterate through updating the centers until we're done |
| final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; |
| for (int count = 0; count < max; count++) { |
| boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty()); |
| List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters); |
| int changes = assignPointsToClusters(newClusters, points, assignments); |
| clusters = newClusters; |
| |
| // if there were no more changes in the point-to-cluster assignment |
| // and there are no empty clusters left, return the current clusters |
| if (changes == 0 && !hasEmptyCluster) { |
| return clusters; |
| } |
| } |
| return clusters; |
| } |
| |
| /** |
| * @return the random generator |
| */ |
| UniformRandomProvider getRandomGenerator() { |
| return random; |
| } |
| |
| /** |
| * @return the {@link EmptyClusterStrategy} |
| */ |
| EmptyClusterStrategy getEmptyClusterStrategy() { |
| return emptyStrategy; |
| } |
| |
| /** |
| * Adjust the clusters's centers with means of points. |
| * @param clusters the origin clusters |
| * @return adjusted clusters with center points |
| */ |
| List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) { |
| List<CentroidCluster<T>> newClusters = new ArrayList<>(); |
| for (final CentroidCluster<T> cluster : clusters) { |
| final Clusterable newCenter; |
| if (cluster.getPoints().isEmpty()) { |
| switch (emptyStrategy) { |
| case LARGEST_VARIANCE : |
| newCenter = getPointFromLargestVarianceCluster(clusters); |
| break; |
| case LARGEST_POINTS_NUMBER : |
| newCenter = getPointFromLargestNumberCluster(clusters); |
| break; |
| case FARTHEST_POINT : |
| newCenter = getFarthestPoint(clusters); |
| break; |
| default : |
| throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| } |
| } else { |
| newCenter = cluster.centroid(); |
| } |
| newClusters.add(new CentroidCluster<>(newCenter)); |
| } |
| return newClusters; |
| } |
| |
| /** |
| * Adds the given points to the closest {@link Cluster}. |
| * |
| * @param clusters the {@link Cluster}s to add the points to |
| * @param points the points to add to the given {@link Cluster}s |
| * @param assignments points assignments to clusters |
| * @return the number of points assigned to different clusters as the iteration before |
| */ |
| private int assignPointsToClusters(final List<CentroidCluster<T>> clusters, |
| final Collection<T> points, |
| final int[] assignments) { |
| int assignedDifferently = 0; |
| int pointIndex = 0; |
| for (final T p : points) { |
| int clusterIndex = getNearestCluster(clusters, p); |
| if (clusterIndex != assignments[pointIndex]) { |
| assignedDifferently++; |
| } |
| |
| CentroidCluster<T> cluster = clusters.get(clusterIndex); |
| cluster.addPoint(p); |
| assignments[pointIndex++] = clusterIndex; |
| } |
| |
| return assignedDifferently; |
| } |
| |
| /** |
| * Use K-means++ to choose the initial centers. |
| * |
| * @param points the points to choose the initial centers from |
| * @return the initial centers |
| */ |
| List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) { |
| |
| // Convert to list for indexed access. Make it unmodifiable, since removal of items |
| // would screw up the logic of this method. |
| final List<T> pointList = Collections.unmodifiableList(new ArrayList<> (points)); |
| |
| // The number of points in the list. |
| final int numPoints = pointList.size(); |
| |
| // Set the corresponding element in this array to indicate when |
| // elements of pointList are no longer available. |
| final boolean[] taken = new boolean[numPoints]; |
| |
| // The resulting list of initial centers. |
| final List<CentroidCluster<T>> resultSet = new ArrayList<>(); |
| |
| // Choose one center uniformly at random from among the data points. |
| final int firstPointIndex = random.nextInt(numPoints); |
| |
| final T firstPoint = pointList.get(firstPointIndex); |
| |
| resultSet.add(new CentroidCluster<T>(firstPoint)); |
| |
| // Must mark it as taken |
| taken[firstPointIndex] = true; |
| |
| // To keep track of the minimum distance squared of elements of |
| // pointList to elements of resultSet. |
| final double[] minDistSquared = new double[numPoints]; |
| |
| // Initialize the elements. Since the only point in resultSet is firstPoint, |
| // this is very easy. |
| for (int i = 0; i < numPoints; i++) { |
| if (i != firstPointIndex) { // That point isn't considered |
| double d = distance(firstPoint, pointList.get(i)); |
| minDistSquared[i] = d*d; |
| } |
| } |
| |
| while (resultSet.size() < numberOfClusters) { |
| |
| // Sum up the squared distances for the points in pointList not |
| // already taken. |
| double distSqSum = 0.0; |
| |
| for (int i = 0; i < numPoints; i++) { |
| if (!taken[i]) { |
| distSqSum += minDistSquared[i]; |
| } |
| } |
| |
| // Add one new data point as a center. Each point x is chosen with |
| // probability proportional to D(x)2 |
| final double r = random.nextDouble() * distSqSum; |
| |
| // The index of the next point to be added to the resultSet. |
| int nextPointIndex = -1; |
| |
| // Sum through the squared min distances again, stopping when |
| // sum >= r. |
| double sum = 0.0; |
| for (int i = 0; i < numPoints; i++) { |
| if (!taken[i]) { |
| sum += minDistSquared[i]; |
| if (sum >= r) { |
| nextPointIndex = i; |
| break; |
| } |
| } |
| } |
| |
| // If it's not set to >= 0, the point wasn't found in the previous |
| // for loop, probably because distances are extremely small. Just pick |
| // the last available point. |
| if (nextPointIndex == -1) { |
| for (int i = numPoints - 1; i >= 0; i--) { |
| if (!taken[i]) { |
| nextPointIndex = i; |
| break; |
| } |
| } |
| } |
| |
| // We found one. |
| if (nextPointIndex >= 0) { |
| |
| final T p = pointList.get(nextPointIndex); |
| |
| resultSet.add(new CentroidCluster<T> (p)); |
| |
| // Mark it as taken. |
| taken[nextPointIndex] = true; |
| |
| if (resultSet.size() < numberOfClusters) { |
| // Now update elements of minDistSquared. We only have to compute |
| // the distance to the new center to do this. |
| for (int j = 0; j < numPoints; j++) { |
| // Only have to worry about the points still not taken. |
| if (!taken[j]) { |
| double d = distance(p, pointList.get(j)); |
| double d2 = d * d; |
| if (d2 < minDistSquared[j]) { |
| minDistSquared[j] = d2; |
| } |
| } |
| } |
| } |
| |
| } else { |
| // None found -- |
| // Break from the while loop to prevent |
| // an infinite loop. |
| break; |
| } |
| } |
| |
| return resultSet; |
| } |
| |
| /** |
| * Get a random point from the {@link Cluster} with the largest distance variance. |
| * |
| * @param clusters the {@link Cluster}s to search |
| * @return a random point from the selected cluster |
| * @throws ConvergenceException if clusters are all empty |
| */ |
| private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) { |
| double maxVariance = Double.NEGATIVE_INFINITY; |
| Cluster<T> selected = null; |
| for (final CentroidCluster<T> cluster : clusters) { |
| if (!cluster.getPoints().isEmpty()) { |
| |
| // compute the distance variance of the current cluster |
| final Clusterable center = cluster.getCenter(); |
| final Variance stat = new Variance(); |
| for (final T point : cluster.getPoints()) { |
| stat.increment(distance(point, center)); |
| } |
| final double variance = stat.getResult(); |
| |
| // select the cluster with the largest variance |
| if (variance > maxVariance) { |
| maxVariance = variance; |
| selected = cluster; |
| } |
| |
| } |
| } |
| |
| // did we find at least one non-empty cluster ? |
| if (selected == null) { |
| throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| } |
| |
| // extract a random point from the cluster |
| final List<T> selectedPoints = selected.getPoints(); |
| return selectedPoints.remove(random.nextInt(selectedPoints.size())); |
| |
| } |
| |
| /** |
| * Get a random point from the {@link Cluster} with the largest number of points. |
| * |
| * @param clusters the {@link Cluster}s to search |
| * @return a random point from the selected cluster |
| * @throws ConvergenceException if clusters are all empty |
| */ |
| private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) { |
| int maxNumber = 0; |
| Cluster<T> selected = null; |
| for (final Cluster<T> cluster : clusters) { |
| |
| // get the number of points of the current cluster |
| final int number = cluster.getPoints().size(); |
| |
| // select the cluster with the largest number of points |
| if (number > maxNumber) { |
| maxNumber = number; |
| selected = cluster; |
| } |
| |
| } |
| |
| // did we find at least one non-empty cluster ? |
| if (selected == null) { |
| throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| } |
| |
| // extract a random point from the cluster |
| final List<T> selectedPoints = selected.getPoints(); |
| return selectedPoints.remove(random.nextInt(selectedPoints.size())); |
| |
| } |
| |
| /** |
| * Get the point farthest to its cluster center. |
| * |
| * @param clusters the {@link Cluster}s to search |
| * @return point farthest to its cluster center |
| * @throws ConvergenceException if clusters are all empty |
| */ |
| private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) { |
| double maxDistance = Double.NEGATIVE_INFINITY; |
| Cluster<T> selectedCluster = null; |
| int selectedPoint = -1; |
| for (final CentroidCluster<T> cluster : clusters) { |
| |
| // get the farthest point |
| final Clusterable center = cluster.getCenter(); |
| final List<T> points = cluster.getPoints(); |
| for (int i = 0; i < points.size(); ++i) { |
| final double distance = distance(points.get(i), center); |
| if (distance > maxDistance) { |
| maxDistance = distance; |
| selectedCluster = cluster; |
| selectedPoint = i; |
| } |
| } |
| |
| } |
| |
| // did we find at least one non-empty cluster ? |
| if (selectedCluster == null) { |
| throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); |
| } |
| |
| return selectedCluster.getPoints().remove(selectedPoint); |
| |
| } |
| |
| /** |
| * Returns the nearest {@link Cluster} to the given point. |
| * |
| * @param clusters the {@link Cluster}s to search |
| * @param point the point to find the nearest {@link Cluster} for |
| * @return the index of the nearest {@link Cluster} to the given point |
| */ |
| private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) { |
| double minDistance = Double.MAX_VALUE; |
| int clusterIndex = 0; |
| int minCluster = 0; |
| for (final CentroidCluster<T> c : clusters) { |
| final double distance = distance(point, c.getCenter()); |
| if (distance < minDistance) { |
| minDistance = distance; |
| minCluster = clusterIndex; |
| } |
| clusterIndex++; |
| } |
| return minCluster; |
| } |
| } |