// blob: 13c98a498ae63a2c3cd5d54379a627188e0964a4 [file] [log] [blame]
package org.apache.samoa.evaluation.measures;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.samoa.moa.cluster.Clustering;
import org.apache.samoa.moa.core.DataPoint;
import org.apache.samoa.moa.evaluation.MeasureCollection;
import org.apache.samoa.moa.evaluation.MembershipMatrix;
/**
 * Measure collection that scores a clustering with two statistics, registered
 * under the names "van Dongen" and "Rand statistic". Both are derived from the
 * cluster/class weight table of a {@link MembershipMatrix} built from the
 * evaluated clustering and the evaluation points.
 *
 * <p>
 * The class also offers {@link #cindex(Clustering, ArrayList)}; that value is
 * not registered in {@link #getNames()} and is not reported by
 * {@link #evaluateClustering(Clustering, Clustering, ArrayList)}.
 */
public class StatisticalCollection extends MeasureCollection {

  /**
   * Inclusion probability a point has to exceed to be counted as a member of
   * a cluster in {@link #cindex(Clustering, ArrayList)}.
   */
  private static final double INCLUSION_THRESHOLD = 0.8;

  /** When true, intermediate values are printed to standard output. */
  private boolean debug = false;

  /**
   * @return the names of the measures reported by this collection, in the
   *         order matching {@link #getDefaultEnabled()}
   */
  @Override
  protected String[] getNames() {
    return new String[] { "van Dongen", "Rand statistic" };
  }

  /**
   * @return one flag per measure name; both measures are disabled by default
   */
  @Override
  protected boolean[] getDefaultEnabled() {
    return new boolean[] { false, false };
  }

  /**
   * Computes the van Dongen value and the Rand statistic for the given
   * clustering and records each with {@code addValue}.
   *
   * <p>
   * If the membership matrix reports zero entries (van Dongen) or fewer than
   * two entries (Rand statistic) the respective division is by zero and the
   * resulting non-finite value is passed to {@code addValue} unchanged.
   *
   * @param clustering
   *          the clustering to evaluate
   * @param trueClustering
   *          not used by this implementation
   * @param points
   *          the points the membership matrix is built from
   * @throws Exception
   *           declared by the overridden method; this body throws nothing
   *           explicitly
   */
  @Override
  public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points)
      throws Exception {
    MembershipMatrix mm = new MembershipMatrix(clustering, points);
    int numClasses = mm.getNumClasses();
    // NOTE(review): the matrix is indexed with one row more than the clustering
    // has clusters - presumably an extra row for points that belong to no
    // cluster; confirm against MembershipMatrix.
    int numCluster = clustering.size() + 1;
    int n = mm.getTotalEntries();

    // Sum, over all cluster rows, of the largest class weight in that row.
    double dongenMaxFC = 0;
    for (int i = 0; i < numCluster; i++) {
      double rowMax = 0;
      for (int j = 0; j < numClasses; j++) {
        double weight = mm.getClusterClassWeight(i, j);
        if (weight > rowMax) {
          rowMax = weight;
        }
      }
      dongenMaxFC += rowMax;
    }

    // Sum, over all class columns, of the largest cluster weight in that column.
    double dongenMaxHC = 0;
    for (int j = 0; j < numClasses; j++) {
      double columnMax = 0;
      for (int i = 0; i < numCluster; i++) {
        double weight = mm.getClusterClassWeight(i, j);
        if (weight > columnMax) {
          columnMax = weight;
        }
      }
      dongenMaxHC += columnMax;
    }

    // 2.0 keeps the denominator in floating point (no int overflow for large n).
    // A normalized variant would instead be
    // 1 - (2n - maxFC - maxHC) / (2n - maxClusterSum - maxClassSum).
    double dongen = (dongenMaxFC + dongenMaxHC) / (2.0 * n);
    if (debug) {
      System.out.println("Dongen HC:" + dongenMaxHC + " FC:" + dongenMaxFC + " Total:" + dongen + " n " + n);
    }
    addValue("van Dongen", dongen);

    // Rand index
    // http://www.cais.ntu.edu.sg/~qihe/menu4.html
    double m1 = 0; // pairs sharing a class
    for (int j = 0; j < numClasses; j++) {
      m1 += pairCount(mm.getClassSum(j));
    }
    double m2 = 0; // pairs sharing a cluster row
    for (int i = 0; i < numCluster; i++) {
      m2 += pairCount(mm.getClusterSum(i));
    }
    double m = 0; // pairs sharing both class and cluster row
    for (int i = 0; i < numCluster; i++) {
      for (int j = 0; j < numClasses; j++) {
        m += pairCount(mm.getClusterClassWeight(i, j));
      }
    }
    // Total number of pairs, computed in double: n * (n - 1) in int arithmetic
    // overflows as soon as n exceeds 46341.
    double M = pairCount(n);
    // A normalized variant would be (m - m1*m2/M) / (m1/2 + m2/2 - m1*m2/M).
    double rand = (M - m1 - m2 + 2 * m) / M;
    addValue("Rand statistic", rand);
  }

  /**
   * Computes a C-index style value for the clustering: the mean distance of
   * all within-cluster point pairs, rescaled by the smallest and largest
   * within-cluster pair distance, i.e. {@code (mean - min) / (max - min)}.
   *
   * <p>
   * A point is treated as a member of every cluster whose inclusion
   * probability for it exceeds 0.8, so a point may contribute to several
   * clusters.
   *
   * @param clustering
   *          the clustering whose clusters are examined
   * @param points
   *          the points to assign to clusters and measure distances between
   * @return the rescaled mean within-cluster distance; {@code 0} when no
   *         within-cluster pair exists or when all pair distances are equal
   *         (the rescaling range is then zero)
   */
  public double cindex(Clustering clustering, ArrayList<DataPoint> points) {
    int numClusters = clustering.size();
    double[] minWithinClusters = new double[numClusters];
    double[] maxWithinClusters = new double[numClusters];
    ArrayList<ArrayList<Integer>> pointsInClusters = new ArrayList<ArrayList<Integer>>(numClusters);
    for (int c = 0; c < numClusters; c++) {
      pointsInClusters.add(new ArrayList<Integer>());
      // Infinite sentinels: Double.MIN_VALUE is the smallest POSITIVE double and
      // would hide a genuine maximum distance of 0.
      minWithinClusters[c] = Double.POSITIVE_INFINITY;
      maxWithinClusters[c] = Double.NEGATIVE_INFINITY;
    }

    // Collect, per cluster, the indices of the points it includes.
    for (int p = 0; p < points.size(); p++) {
      DataPoint candidate = points.get(p);
      for (int c = 0; c < numClusters; c++) {
        if (clustering.get(c).getInclusionProbability(candidate) > INCLUSION_THRESHOLD) {
          pointsInClusters.get(c).add(p);
        }
      }
    }

    // Accumulate within-cluster pair distances plus per-cluster min and max.
    double withinClustersDistance = 0;
    int numDistancesWithin = 0;
    for (int c = 0; c < numClusters; c++) {
      ArrayList<Integer> members = pointsInClusters.get(c);
      for (int a = 0; a < members.size(); a++) {
        DataPoint first = points.get(members.get(a));
        for (int b = a + 1; b < members.size(); b++) {
          DataPoint second = points.get(members.get(b));
          double dist = first.getDistance(second);
          numDistancesWithin++;
          withinClustersDistance += dist;
          if (dist < minWithinClusters[c]) {
            minWithinClusters[c] = dist;
          }
          if (dist > maxWithinClusters[c]) {
            maxWithinClusters[c] = dist;
          }
        }
      }
    }

    // Reduce the per-cluster extremes to global ones.
    double minWithin = Double.POSITIVE_INFINITY;
    double maxWithin = Double.NEGATIVE_INFINITY;
    for (int c = 0; c < numClusters; c++) {
      if (minWithinClusters[c] < minWithin) {
        minWithin = minWithinClusters[c];
      }
      if (maxWithinClusters[c] > maxWithin) {
        maxWithin = maxWithinClusters[c];
      }
    }

    double cindex = 0;
    if (numDistancesWithin != 0) {
      double range = maxWithin - minWithin;
      // A zero range (single pair, or all distances equal) would give 0/0 = NaN;
      // keep the default of 0 in that case.
      if (range > 0) {
        double meanWithinClustersDistance = withinClustersDistance / numDistancesWithin;
        cindex = (meanWithinClustersDistance - minWithin) / range;
      }
    }

    if (debug) {
      System.out.println("Min:" + Arrays.toString(minWithinClusters));
      System.out.println("Max:" + Arrays.toString(maxWithinClusters));
      System.out.println("totalWithin:" + numDistancesWithin);
    }
    return cindex;
  }

  /** Returns the number of unordered pairs among {@code count} items, {@code count * (count - 1) / 2}. */
  private static double pairCount(double count) {
    return count * (count - 1) / 2.0;
  }
}