blob: d0e4bc90f52a175391187070fe73311544732124 [file] [log] [blame]
package org.apache.samoa.evaluation.measures;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.util.ArrayList;
import org.apache.samoa.moa.cluster.Cluster;
import org.apache.samoa.moa.cluster.Clustering;
import org.apache.samoa.moa.core.DataPoint;
import org.apache.samoa.moa.evaluation.MeasureCollection;
/**
 * Silhouette coefficient of a found clustering, reported as a single measure named
 * {@code "SilhCoeff"}.
 *
 * <p>A point is treated as a member of a cluster when the cluster's inclusion probability for
 * that point exceeds {@link #POINT_INCLUSION_PROB_THRESHOLD}. Clusters may overlap, so a point
 * can have several "own" clusters. For every point that belongs to at least one found cluster
 * the score {@code (b - a) / max(a, b)} is computed, where {@code a} is the smallest average
 * distance to the members of one of its own clusters and {@code b} is the smallest average
 * distance to the members of any remaining cluster. The mean score is mapped from
 * {@code [-1, 1]} to {@code [0, 1]} before it is recorded.
 */
public class SilhouetteCoefficient extends MeasureCollection {

  /** Inclusion probability a point must exceed to count as a member of a cluster. */
  private static final double POINT_INCLUSION_PROB_THRESHOLD = 0.8;

  public SilhouetteCoefficient() {
    super();
  }

  /**
   * @return one flag per entry of {@link #getNames()}; the single measure is flagged
   *         {@code false}
   */
  @Override
  protected boolean[] getDefaultEnabled() {
    return new boolean[] { false };
  }

  /**
   * @return the names of the measures in this collection (only {@code "SilhCoeff"})
   */
  @Override
  public String[] getNames() {
    return new String[] { "SilhCoeff" };
  }

  /**
   * Computes the normalized mean silhouette score of {@code points} with respect to
   * {@code clustering} and records it via {@code addValue(0, ...)}.
   *
   * <p>Side effects: every scored point gets the measure values {@code "SC - own"},
   * {@code "SC - other"} and {@code "SC"} set. The first two are {@link Double#MAX_VALUE} when
   * no suitable cluster with other members exists.
   *
   * <p>Special cases:
   * <ul>
   * <li>Points that belong to no found cluster are not scored.</li>
   * <li>Points whose {@code classValue()} is {@code -1} are never used as neighbours when
   * averaging distances.</li>
   * <li>A point whose own clusters contain no other usable member scores {@code 0} (the usual
   * convention for singleton clusters).</li>
   * <li>A point whose own and other average distances are both {@code 0} scores {@code 0}
   * instead of producing {@code NaN}.</li>
   * <li>If no other cluster is available for comparison, {@code b} stays at
   * {@link Double#MAX_VALUE}, which yields a score close to {@code 1}.</li>
   * <li>If no point is scored at all, the recorded value is {@code 0.5}.</li>
   * </ul>
   *
   * @param clustering     the found clustering to evaluate
   * @param trueClustering the ground-truth clustering; not used by this measure
   * @param points         the points the clustering is evaluated on
   */
  public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) {
    int numFCluster = clustering.size();
    int numPoints = points.size();

    // isMember[p][fc]: does point p belong to found cluster fc?
    boolean[][] isMember = new boolean[numPoints][numFCluster];
    for (int p = 0; p < numPoints; p++) {
      DataPoint point = points.get(p);
      for (int fc = 0; fc < numFCluster; fc++) {
        Cluster cl = clustering.get(fc);
        isMember[p][fc] = cl.getInclusionProbability(point) > POINT_INCLUSION_PROB_THRESHOLD;
      }
    }

    double silhCoeff = 0.0;
    int totalCount = 0;
    for (int p = 0; p < numPoints; p++) {
      ArrayList<Integer> ownClusters = new ArrayList<>();
      for (int fc = 0; fc < numFCluster; fc++) {
        if (isMember[p][fc]) {
          ownClusters.add(fc);
        }
      }
      if (ownClusters.isEmpty()) {
        // The point lies outside every found cluster, so it has no silhouette.
        continue;
      }

      DataPoint point = points.get(p);

      // Sum of distances from p to the members of each cluster, and the member counts.
      double[] distanceByClusters = new double[numFCluster];
      int[] countsByClusters = new int[numFCluster];
      for (int p1 = 0; p1 < numPoints; p1++) {
        DataPoint point1 = points.get(p1);
        if (p1 == p || point1.classValue() == -1) {
          continue;
        }
        // Computed lazily and at most once per pair, however many clusters contain point1.
        double distance = Double.NaN;
        for (int fc = 0; fc < numFCluster; fc++) {
          if (!isMember[p1][fc]) {
            continue;
          }
          if (Double.isNaN(distance)) {
            distance = point.getDistance(point1);
          }
          distanceByClusters[fc] += distance;
          countsByClusters[fc]++;
        }
      }

      // Closest OWN cluster (clusters might overlap, so there can be several).
      double minAvgDistanceOwn = Double.MAX_VALUE;
      int minOwnIndex = -1;
      for (int fc : ownClusters) {
        if (countsByClusters[fc] == 0) {
          // No other usable member: the average distance is undefined.
          continue;
        }
        double normDist = distanceByClusters[fc] / countsByClusters[fc];
        if (normDist < minAvgDistanceOwn) {
          minAvgDistanceOwn = normDist;
          minOwnIndex = fc;
        }
      }

      // Closest remaining cluster (a different found cluster, or another own cluster).
      double minAvgDistanceOther = Double.MAX_VALUE;
      for (int fc = 0; fc < numFCluster; fc++) {
        if (fc == minOwnIndex || countsByClusters[fc] == 0) {
          continue;
        }
        double normDist = distanceByClusters[fc] / countsByClusters[fc];
        if (normDist < minAvgDistanceOther) {
          minAvgDistanceOther = normDist;
        }
      }

      double denominator = Math.max(minAvgDistanceOther, minAvgDistanceOwn);
      double silhP;
      if (minOwnIndex == -1 || denominator == 0.0) {
        // Either the point is alone in its own cluster(s), so the MAX_VALUE placeholder must not
        // enter the formula, or both distances are zero and the ratio would be 0/0 = NaN.
        silhP = 0.0;
      } else {
        silhP = (minAvgDistanceOther - minAvgDistanceOwn) / denominator;
      }

      point.setMeasureValue("SC - own", minAvgDistanceOwn);
      point.setMeasureValue("SC - other", minAvgDistanceOther);
      point.setMeasureValue("SC", silhP);

      silhCoeff += silhP;
      totalCount++;
    }

    if (totalCount > 0) {
      silhCoeff /= totalCount;
    }
    // normalize from [-1, 1] to [0, 1]
    silhCoeff = (silhCoeff + 1) / 2.0;
    addValue(0, silhCoeff);
  }
}