| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.commons.math4.ml.clustering; |
| |
| import org.apache.commons.math4.exception.NumberIsTooSmallException; |
| import org.apache.commons.math4.ml.distance.DistanceMeasure; |
| import org.apache.commons.math4.util.MathUtils; |
| import org.apache.commons.math4.util.Pair; |
| import org.apache.commons.rng.UniformRandomProvider; |
| import org.apache.commons.rng.sampling.ListSampler; |
| |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.List; |
| |
| /** |
| * Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf"> |
| * based on KMeans</a>. |
| * |
| * @param <T> Type of the points to cluster. |
| */ |
| public class MiniBatchKMeansClusterer<T extends Clusterable> |
| extends KMeansPlusPlusClusterer<T> { |
| /** Batch data size in iteration. */ |
| private final int batchSize; |
| /** Iteration count of initialize the centers. */ |
| private final int initIterations; |
| /** Data size of batch to initialize the centers. */ |
| private final int initBatchSize; |
| /** Maximum number of iterations during which no improvement is occuring. */ |
| private final int maxNoImprovementTimes; |
| |
| |
| /** |
| * Build a clusterer. |
| * |
| * @param k Number of clusters to split the data into. |
| * @param maxIterations Maximum number of iterations to run the algorithm for all the points, |
| * The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize}, |
| * where {@code size} is the number of points to cluster. |
| * Disabled if negative. |
| * @param batchSize Batch size for training iterations. |
| * @param initIterations Number of iterations allowed in order to find out the best initial centers. |
| * @param initBatchSize Batch size for initializing the clusters centers. |
| * A value of {@code 3 * batchSize} should be suitable in most cases. |
| * @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring. |
| * A value of 10 is suitable in most cases. |
| * @param measure Distance measure. |
| * @param random Random generator. |
| * @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations. |
| */ |
| public MiniBatchKMeansClusterer(final int k, |
| final int maxIterations, |
| final int batchSize, |
| final int initIterations, |
| final int initBatchSize, |
| final int maxNoImprovementTimes, |
| final DistanceMeasure measure, |
| final UniformRandomProvider random, |
| final EmptyClusterStrategy emptyStrategy) { |
| super(k, maxIterations, measure, random, emptyStrategy); |
| |
| if (batchSize < 1) { |
| throw new NumberIsTooSmallException(batchSize, 1, true); |
| } |
| if (initIterations < 1) { |
| throw new NumberIsTooSmallException(initIterations, 1, true); |
| } |
| if (initBatchSize < 1) { |
| throw new NumberIsTooSmallException(initBatchSize, 1, true); |
| } |
| if (maxNoImprovementTimes < 1) { |
| throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true); |
| } |
| |
| this.batchSize = batchSize; |
| this.initIterations = initIterations; |
| this.initBatchSize = initBatchSize; |
| this.maxNoImprovementTimes = maxNoImprovementTimes; |
| } |
| |
| /** |
| * Runs the MiniBatch K-means clustering algorithm. |
| * |
| * @param points Points to cluster (cannot be {@code null}). |
| * @return the clusters. |
| * @throws org.apache.commons.math4.exception.MathIllegalArgumentException |
| * if the number of points is smaller than the number of clusters. |
| */ |
| @Override |
| public List<CentroidCluster<T>> cluster(final Collection<T> points) { |
| // Sanity check. |
| MathUtils.checkNotNull(points); |
| if (points.size() < getNumberOfClusters()) { |
| throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false); |
| } |
| |
| final int pointSize = points.size(); |
| final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0); |
| final int max = getMaxIterations() < 0 ? |
| Integer.MAX_VALUE : |
| getMaxIterations() * batchCount; |
| |
| final List<T> pointList = new ArrayList<>(points); |
| List<CentroidCluster<T>> clusters = initialCenters(pointList); |
| |
| final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize, |
| maxNoImprovementTimes); |
| for (int i = 0; i < max; i++) { |
| clearClustersPoints(clusters); |
| final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize); |
| // Training step. |
| final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters); |
| final double squareDistance = pair.getFirst(); |
| clusters = pair.getSecond(); |
| // Check whether the training can finished early. |
| if (evaluator.converge(squareDistance, pointSize)) { |
| break; |
| } |
| } |
| |
| // Add every mini batch points to their nearest cluster. |
| clearClustersPoints(clusters); |
| for (final T point : points) { |
| addToNearestCentroidCluster(point, clusters); |
| } |
| |
| return clusters; |
| } |
| |
| /** |
| * Helper method. |
| * |
| * @param clusters Clusters to clear. |
| */ |
| private void clearClustersPoints(final List<CentroidCluster<T>> clusters) { |
| for (CentroidCluster<T> cluster : clusters) { |
| cluster.getPoints().clear(); |
| } |
| } |
| |
| /** |
| * Mini batch iteration step. |
| * |
| * @param batchPoints Points selected for this batch. |
| * @param clusters Centers of the clusters. |
| * @return the squared distance of all the batch points to the nearest center. |
| */ |
| private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints, |
| final List<CentroidCluster<T>> clusters) { |
| // Add every mini batch points to their nearest cluster. |
| for (final T point : batchPoints) { |
| addToNearestCentroidCluster(point, clusters); |
| } |
| final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters); |
| // Add every mini batch points to their nearest cluster again. |
| double squareDistance = 0.0; |
| for (T point : batchPoints) { |
| final double d = addToNearestCentroidCluster(point, newClusters); |
| squareDistance += d * d; |
| } |
| |
| return new Pair<>(squareDistance, newClusters); |
| } |
| |
| /** |
| * Initializes the clusters centers. |
| * |
| * @param points Points used to initialize the centers. |
| * @return clusters with their center initialized. |
| */ |
| private List<CentroidCluster<T>> initialCenters(final List<T> points) { |
| final List<T> validPoints = initBatchSize < points.size() ? |
| ListSampler.sample(getRandomGenerator(), points, initBatchSize) : |
| new ArrayList<>(points); |
| double nearestSquareDistance = Double.POSITIVE_INFINITY; |
| List<CentroidCluster<T>> bestCenters = null; |
| |
| for (int i = 0; i < initIterations; i++) { |
| final List<T> initialPoints = (initBatchSize < points.size()) ? |
| ListSampler.sample(getRandomGenerator(), points, initBatchSize) : |
| new ArrayList<>(points); |
| final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints); |
| final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters); |
| final double squareDistance = pair.getFirst(); |
| final List<CentroidCluster<T>> newClusters = pair.getSecond(); |
| //Find out a best centers that has the nearest total square distance. |
| if (squareDistance < nearestSquareDistance) { |
| nearestSquareDistance = squareDistance; |
| bestCenters = newClusters; |
| } |
| } |
| return bestCenters; |
| } |
| |
| /** |
| * Adds a point to the cluster whose center is closest. |
| * |
| * @param point Point to add. |
| * @param clusters Clusters. |
| * @return the distance between point and the closest center. |
| */ |
| private double addToNearestCentroidCluster(final T point, |
| final List<CentroidCluster<T>> clusters) { |
| double minDistance = Double.POSITIVE_INFINITY; |
| CentroidCluster<T> closestCentroidCluster = null; |
| |
| // Find cluster closest to the point. |
| for (CentroidCluster<T> centroidCluster : clusters) { |
| final double distance = distance(point, centroidCluster.getCenter()); |
| if (distance < minDistance) { |
| minDistance = distance; |
| closestCentroidCluster = centroidCluster; |
| } |
| } |
| MathUtils.checkNotNull(closestCentroidCluster); |
| closestCentroidCluster.addPoint(point); |
| |
| return minDistance; |
| } |
| |
| /** |
| * Stopping criterion. |
| * The evaluator checks whether improvement occurred during the |
| * {@link #maxNoImprovementTimes allowed number of successive iterations}. |
| */ |
| private static class ImprovementEvaluator { |
| /** Batch size. */ |
| private final int batchSize; |
| /** Maximum number of iterations during which no improvement is occuring. */ |
| private final int maxNoImprovementTimes; |
| /** |
| * <a href="https://en.wikipedia.org/wiki/Moving_average"> |
| * Exponentially Weighted Average</a> of the squared |
| * diff to monitor the convergence while discarding |
| * minibatch-local stochastic variability. |
| */ |
| private double ewaInertia = Double.NaN; |
| /** Minimum value of {@link #ewaInertia} during iteration. */ |
| private double ewaInertiaMin = Double.POSITIVE_INFINITY; |
| /** Number of iteration during which {@link #ewaInertia} did not improve. */ |
| private int noImprovementTimes = 0; |
| |
| /** |
| * @param batchSize Number of elements for each batch iteration. |
| * @param maxNoImprovementTimes Maximum number of iterations during |
| * which no improvement is occuring. |
| */ |
| private ImprovementEvaluator(int batchSize, |
| int maxNoImprovementTimes) { |
| this.batchSize = batchSize; |
| this.maxNoImprovementTimes = maxNoImprovementTimes; |
| } |
| |
| /** |
| * Stopping criterion. |
| * |
| * @param squareDistance Total square distance from the batch points |
| * to their nearest center. |
| * @param pointSize Number of data points. |
| * @return {@code true} if no improvement was made after the allowed |
| * number of iterations, {@code false} otherwise. |
| */ |
| public boolean converge(final double squareDistance, |
| final int pointSize) { |
| final double batchInertia = squareDistance / batchSize; |
| if (Double.isNaN(ewaInertia)) { |
| ewaInertia = batchInertia; |
| } else { |
| final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1); |
| ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha; |
| } |
| |
| if (ewaInertia < ewaInertiaMin) { |
| // Improved. |
| noImprovementTimes = 0; |
| ewaInertiaMin = ewaInertia; |
| } else { |
| // No improvement. |
| ++noImprovementTimes; |
| } |
| |
| return noImprovementTimes >= maxNoImprovementTimes; |
| } |
| } |
| } |