blob: b2f2e0ce7285b95ab8124b9be08b0125b73d9a27 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.ml.clustering;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.ml.distance.EuclideanDistance;
import org.apache.commons.rng.simple.RandomSource;
import org.apache.commons.rng.UniformRandomProvider;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
/**
 * Tests for {@link KMeansPlusPlusClusterer}.
 */
public class KMeansPlusPlusClustererTest {

    /** Random generator handed to the clusterers; rebuilt from a fixed seed before each test. */
    private UniformRandomProvider random;

    /**
     * Creates the random generator from a fixed seed.
     */
    @Before
    public void setUp() {
        random = RandomSource.create(RandomSource.MT_64, 1746432956321L);
    }

    /**
     * JIRA: MATH-305
     *
     * Two points, one cluster, one iteration.
     */
    @Test
    public void testPerformClusterAnalysisDegenerate() {
        final KMeansPlusPlusClusterer<DoublePoint> transformer =
            new KMeansPlusPlusClusterer<>(1, 1);
        final DoublePoint[] points = {
            new DoublePoint(new int[] { 1959, 325100 }),
            new DoublePoint(new int[] { 1960, 373200 }),
        };

        final List<? extends Cluster<DoublePoint>> clusters =
            transformer.cluster(Arrays.asList(points));

        Assert.assertEquals(1, clusters.size());
        Assert.assertEquals(2, clusters.get(0).getPoints().size());

        // Look the points up through newly built instances carrying the same coordinates.
        final DoublePoint pt1 = new DoublePoint(new int[] { 1959, 325100 });
        final DoublePoint pt2 = new DoublePoint(new int[] { 1960, 373200 });
        Assert.assertTrue(clusters.get(0).getPoints().contains(pt1));
        Assert.assertTrue(clusters.get(0).getPoints().contains(pt2));
    }

    /**
     * Clusters a fixed set of points with large, regularly spaced coordinates into
     * 2 up to (number of points - 1) clusters, once per empty-cluster strategy, and checks
     * that the requested number of clusters is returned and that the cluster sizes add up
     * to the number of input points.
     */
    @Test
    public void testCertainSpace() {
        final KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
            KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
            KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
            KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
        };
        final int numberOfVariables = 27;
        final int coordinatesPerPoint = 4;
        // test values will be multiplied
        final int multiplier = 1000000;

        for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
            // define the space which will break the cluster algorithm
            final DoublePoint[] breakingPoints = new DoublePoint[numberOfVariables];
            for (int i = 0; i < numberOfVariables; i++) {
                final int[] coordinates = new int[coordinatesPerPoint];
                for (int j = 0; j < coordinatesPerPoint; j++) {
                    // Coordinate j of point i starts at 1 and advances by numberOfVariables
                    // for every step in either i or j, then gets scaled by the multiplier.
                    coordinates[j] = (1 + (i + j) * numberOfVariables) * multiplier;
                }
                breakingPoints[i] = new DoublePoint(coordinates);
            }

            for (int n = 2; n < numberOfVariables; ++n) {
                final KMeansPlusPlusClusterer<DoublePoint> transformer =
                    new KMeansPlusPlusClusterer<>(n, 100, new EuclideanDistance(), random, strategy);

                final List<? extends Cluster<DoublePoint>> clusters =
                    transformer.cluster(Arrays.asList(breakingPoints));

                Assert.assertEquals(n, clusters.size());
                int sum = 0;
                for (Cluster<DoublePoint> cluster : clusters) {
                    sum += cluster.getPoints().size();
                }
                Assert.assertEquals(numberOfVariables, sum);
            }
        }
    }

    /**
     * A helper class for testSmallDistances(). It is a {@link EuclideanDistance} whose
     * result is scaled by 0.001, so that it tends to return distances less than 1.
     * Declared static because it does not use the enclosing test instance.
     */
    private static class CloseDistance extends EuclideanDistance {
        private static final long serialVersionUID = 1L;

        /**
         * @param a first point
         * @param b second point
         * @return the Euclidean distance between {@code a} and {@code b}, multiplied by 0.001
         */
        @Override
        public double compute(double[] a, double[] b) {
            return super.compute(a, b) * 0.001;
        }
    }

    /**
     * Test points that are very close together. See issue MATH-546.
     */
    @Test
    public void testSmallDistances() {
        // Create a bunch of points. Most are identical, but one is different by a
        // small distance (as measured by CloseDistance).
        final DoublePoint repeatedPoint = new DoublePoint(new int[] { 0 });
        final DoublePoint uniquePoint = new DoublePoint(new int[] { 1 });

        final int numRepeatedPoints = 10 * 1000;
        final Collection<DoublePoint> points = new ArrayList<>();
        for (int i = 0; i < numRepeatedPoints; ++i) {
            points.add(repeatedPoint);
        }
        points.add(uniquePoint);

        // Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial
        // cluster centers).
        final int numClusters = 2;
        final int numIterations = 0;
        final KMeansPlusPlusClusterer<DoublePoint> clusterer =
            new KMeansPlusPlusClusterer<>(numClusters, numIterations,
                                          new CloseDistance(), random);
        final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);

        // Check that one of the chosen centers is the unique point.
        boolean uniquePointIsCenter = false;
        for (CentroidCluster<DoublePoint> cluster : clusters) {
            if (cluster.getCenter().equals(uniquePoint)) {
                uniquePointIsCenter = true;
                break;
            }
        }
        Assert.assertTrue(uniquePointIsCenter);
    }

    /**
     * Two points cannot be clustered into three clusters. See issue MATH-436.
     */
    @Test(expected=NumberIsTooSmallException.class)
    public void testPerformClusterAnalysisToManyClusters() {
        final KMeansPlusPlusClusterer<DoublePoint> transformer =
            new KMeansPlusPlusClusterer<>(3, 1, new EuclideanDistance(), random);
        final DoublePoint[] points = {
            new DoublePoint(new int[] { 1959, 325100 }),
            new DoublePoint(new int[] { 1960, 373200 })
        };
        // Expected to throw: more clusters requested than points supplied.
        transformer.cluster(Arrays.asList(points));
    }
}