MATH-1516: Interface for ranking a list of clusters.
"MultiKMeansPlusPlusClusterer" updated to use the interface.
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/ClusterRanking.java b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterRanking.java
new file mode 100644
index 0000000..a6c87d7
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/ClusterRanking.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.clustering;
+
+import java.util.List;
+
+/**
+ * Evaluates the quality of a set of clusters.
+ * It is assumed that
+ * <ul>
+ * <li>rank is positive,</li>
+ * <li>higher rank means better clustering.</li>
+ * </ul>
+ */
+@FunctionalInterface
+public interface ClusterRanking<T extends Clusterable> {
+ /**
+ * Computes the rank (higher is better).
+ *
+ * @param clusters Clusters to be evaluated.
+ * @return the rank of the provided {@code clusters}.
+ */
+ double compute(List<? extends Cluster<T>> clusters);
+}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java b/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java
index 61606c8..e1af3af 100644
--- a/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/MultiKMeansPlusPlusClusterer.java
@@ -40,7 +40,7 @@
private final int numTrials;
/** The cluster evaluator to use. */
- private final ClusterEvaluator<T> evaluator;
+ private final ClusterRanking<T> evaluator;
/** Build a clusterer.
* @param clusterer the k-means clusterer to use
@@ -59,7 +59,7 @@
*/
public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
final int numTrials,
- final ClusterEvaluator<T> evaluator) {
+ final ClusterRanking<T> evaluator) {
super(clusterer.getDistanceMeasure());
this.clusterer = clusterer;
this.numTrials = numTrials;
@@ -108,7 +108,7 @@
// at first, we have not found any clusters list yet
List<CentroidCluster<T>> best = null;
- double bestVarianceSum = Double.POSITIVE_INFINITY;
+ double bestRank = Double.NEGATIVE_INFINITY;
// do several clustering trials
for (int i = 0; i < numTrials; ++i) {
@@ -116,20 +116,17 @@
// compute a clusters list
List<CentroidCluster<T>> clusters = clusterer.cluster(points);
- // compute the variance of the current list
- final double varianceSum = evaluator.score(clusters);
+ // compute the rank of the current list
+ final double rank = evaluator.compute(clusters);
- if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
+ if (rank > bestRank) {
// this one is the best we have found so far, remember it
- best = clusters;
- bestVarianceSum = varianceSum;
+ best = clusters;
+ bestRank = rank;
}
-
}
// return the best clusters list found
return best;
-
}
-
}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java
index 0e3d0e5..14a82b7 100644
--- a/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/evaluation/SumOfClusterVariances.java
@@ -21,6 +21,7 @@
import org.apache.commons.math4.ml.clustering.Cluster;
import org.apache.commons.math4.ml.clustering.Clusterable;
+import org.apache.commons.math4.ml.clustering.ClusterRanking;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.stat.descriptive.moment.Variance;
@@ -35,7 +36,8 @@
* @param <T> the type of the clustered points
* @since 3.3
*/
-public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluator<T> {
+public class SumOfClusterVariances<T extends Clusterable> extends ClusterEvaluator<T>
+ implements ClusterRanking<T> {
/**
*
@@ -66,4 +68,9 @@
return varianceSum;
}
+ /** {@inheritDoc} */
+ @Override
+ public double compute(List<? extends Cluster<T>> clusters) {
+ return 1d / score(clusters);
+ }
}