blob: ca6c7d1a6490a710b4cdeed4f2d2c657c0822aee [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.ml.clustering;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.ml.clustering.evaluation.CalinskiHarabasz;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class MiniBatchKMeansClustererTest {
/**
* Assert the illegal parameter throws proper Exceptions.
*/
@Test
public void testConstructorParameterChecks() {
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, -1, 3, 300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, -2, 300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, -300, 10, null, null, null));
expectNumberIsTooSmallException(() -> new MiniBatchKMeansClusterer<>(1, -1, 100, 3, 300, -10, null, null, null));
}
/**
* Expects block throws NumberIsTooSmallException.
* @param block the block need to run.
*/
private void expectNumberIsTooSmallException(Runnable block) {
assertException(block, NumberIsTooSmallException.class);
}
/**
* Compare the result to KMeansPlusPlusClusterer
*/
@Test
public void testCompareToKMeans() {
//Generate 4 cluster
final UniformRandomProvider rng = RandomSource.MT_64.create();
List<DoublePoint> data = generateCircles(rng);
KMeansPlusPlusClusterer<DoublePoint> kMeans =
new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE, rng);
MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
// Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
for (int i = 0; i < 100; i++) {
List<CentroidCluster<DoublePoint>> kMeansClusters = kMeans.cluster(data);
List<CentroidCluster<DoublePoint>> miniBatchKMeansClusters = miniBatchKMeans.cluster(data);
// Assert cluster result has proper clusters count.
Assert.assertEquals(4, kMeansClusters.size());
Assert.assertEquals(kMeansClusters.size(), miniBatchKMeansClusters.size());
int totalDiffCount = 0;
for (CentroidCluster<DoublePoint> kMeanCluster : kMeansClusters) {
// Find out most similar cluster between two clusters, and summary the points count variances.
CentroidCluster<DoublePoint> miniBatchCluster = predict(miniBatchKMeansClusters, kMeanCluster.getCenter());
totalDiffCount += Math.abs(kMeanCluster.getPoints().size() - miniBatchCluster.getPoints().size());
}
// Statistic points different ratio.
double diffPointsRatio = totalDiffCount * 1.0 / data.size();
// Evaluator score different ratio by "CalinskiHarabasz" algorithm.
ClusterEvaluator clusterEvaluator = new CalinskiHarabasz();
double kMeansScore = clusterEvaluator.score(kMeansClusters);
double miniBatchKMeansScore = clusterEvaluator.score(miniBatchKMeansClusters);
double scoreDiffRatio = (kMeansScore - miniBatchKMeansScore) /
kMeansScore;
// MiniBatchKMeansClusterer has few score differences between KMeansClusterer(less then 10%).
Assert.assertTrue(String.format("Different score ratio %f%%!, diff points ratio: %f%%", scoreDiffRatio * 100, diffPointsRatio * 100),
scoreDiffRatio < 0.1);
}
}
/**
* Generate points around 4 circles.
* @param rng RNG.
* @return Generated points.
*/
private List<DoublePoint> generateCircles(UniformRandomProvider random) {
List<DoublePoint> data = new ArrayList<>();
data.addAll(generateCircle(250, new double[]{-1.0, -1.0}, 1.0, random));
data.addAll(generateCircle(260, new double[]{0.0, 0.0}, 0.7, random));
data.addAll(generateCircle(270, new double[]{1.0, 1.0}, 0.7, random));
data.addAll(generateCircle(280, new double[]{2.0, 2.0}, 0.7, random));
return data;
}
/**
* Generate points as circles.
* @param count total points count.
* @param center circle center point.
* @param radius the circle radius points around.
* @param random the Random source.
* @return Generated points.
*/
List<DoublePoint> generateCircle(int count, double[] center, double radius,
UniformRandomProvider random) {
double x0 = center[0];
double y0 = center[1];
ArrayList<DoublePoint> list = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
double ao = random.nextDouble() * 720 - 360;
double r = random.nextDouble() * radius * 2 - radius;
double x1 = x0 + r * Math.cos(ao * Math.PI / 180);
double y1 = y0 + r * Math.sin(ao * Math.PI / 180);
list.add(new DoublePoint(new double[]{x1, y1}));
}
return list;
}
/**
* Assert there should be a exception.
*
* @param block The code block need to assert.
* @param exceptionClass A exception class.
*/
public static void assertException(Runnable block, Class<? extends Throwable> exceptionClass) {
try {
block.run();
Assert.fail(String.format("Expects %s", exceptionClass.getSimpleName()));
} catch (Throwable e) {
if (!exceptionClass.isInstance(e)) {
throw e;
}
}
}
/**
* Use EuclideanDistance as default DistanceMeasure
*/
public static final DistanceMeasure DEFAULT_MEASURE = new EuclideanDistance();
/**
* Predict which cluster is best for the point
*
* @param clusters cluster to predict into
* @param point point to predict
* @param measure distance measurer
* @param <T> type of cluster point
* @return the cluster which has nearest center to the point
*/
public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point, DistanceMeasure measure) {
double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> nearestCluster = null;
for (CentroidCluster<T> cluster : clusters) {
double distance = measure.compute(point.getPoint(), cluster.getCenter().getPoint());
if (distance < minDistance) {
minDistance = distance;
nearestCluster = cluster;
}
}
return nearestCluster;
}
/**
* Predict which cluster is best for the point
*
* @param clusters cluster to predict into
* @param point point to predict
* @param <T> type of cluster point
* @return the cluster which has nearest center to the point
*/
public static <T extends Clusterable> CentroidCluster<T> predict(List<CentroidCluster<T>> clusters, Clusterable point) {
return predict(clusters, point, DEFAULT_MEASURE);
}
}