// blob: d1aa9a789babf44f927c8919c3fdd6b2557b301d [file] [log] [blame]
/*
* Copyright (c) 2011 The S4 Project, http://s4.io.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the specific
* language governing permissions and limitations under the
* License. See accompanying LICENSE file.
*/
package io.s4.example.kmeans;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.s4.App;
import io.s4.Event;
import io.s4.ProcessingElement;
import io.s4.Stream;
/**
 * Processing element holding the state of a single k-means cluster.
 *
 * <p>
 * State is split between two kinds of objects:
 * <ul>
 * <li>each keyed PE instance keeps its own {@code clusterId}, {@code centroid},
 * per-cluster {@code obsCount} and running {@code obsSum};</li>
 * <li>the prototype ({@code pePrototype}) additionally accumulates the global
 * {@code obsCount}, {@code totalDistance} and {@code confusionMatrix} across
 * all instances.</li>
 * </ul>
 *
 * <p>
 * An incoming {@link ObsEvent} with a negative distance is treated as a raw
 * observation: its distance to this cluster's centroid is computed and emitted
 * on the distance stream. Any other event is treated as a labeled observation
 * and is folded into the statistics. Once the global count reaches
 * {@code numVectors} every centroid is recomputed and the statistics are
 * reset.
 */
public class ClusterPE extends ProcessingElement {

    Logger logger = LoggerFactory.getLogger(ClusterPE.class);

    /** Number of clusters; also the dimension of the confusion matrix. */
    private final int numClusters;

    /** Number of components read from each observation vector. */
    private final int vectorSize;

    /** Number of labeled observations that make up one full pass. */
    private final long numVectors;

    /** Stream on which computed distances are emitted. */
    private Stream<ObsEvent> distanceStream;

    private int clusterId;
    private float[] centroid;

    /*
     * On a keyed instance: labeled observations assigned to this cluster
     * during the current pass. On the prototype: global count for the pass.
     */
    private long obsCount = 0;

    /* Component-wise sum of the observations assigned to this cluster. */
    private float[] obsSum;

    /* Only maintained on the prototype. */
    private float totalDistance = 0f;
    private int[][] confusionMatrix;

    /**
     * Creates the prototype and eagerly creates one PE instance per cluster.
     *
     * @param app
     *            the application this PE belongs to
     * @param numClusters
     *            number of clusters (and of PE instances created here)
     * @param vectorSize
     *            number of components per observation vector
     * @param numVectors
     *            number of labeled observations in one pass over the data
     * @param centroids
     *            initial centroids, indexed by cluster id; the arrays are
     *            stored by reference and later updated in place
     */
    public ClusterPE(App app, int numClusters, int vectorSize, long numVectors,
            float[][] centroids) {
        super(app);
        this.numClusters = numClusters;
        this.vectorSize = vectorSize;
        this.numVectors = numVectors;
        confusionMatrix = new int[numClusters][numClusters];

        /*
         * The ClusterPE instances are not event driven. That is, they are
         * created here before events start to flow.
         *
         * The total number of PE instances is given by numClusters; the key of
         * each instance is its cluster id rendered as a decimal string.
         */
        for (int i = 0; i < numClusters; i++) {
            ClusterPE pe = (ClusterPE) this.getInstanceForKey(Integer
                    .toString(i));
            pe.setClusterId(i);
            pe.setCentroid(centroids[i]);
        }
    }

    /**
     * Sets the output stream on the prototype and on every PE instance that
     * already exists.
     *
     * @param distanceStream
     *            stream that receives the distance events
     */
    public void setStream(Stream<ObsEvent> distanceStream) {

        /* Init prototype. */
        this.distanceStream = distanceStream;

        /*
         * We also need to set the stream in the instances we created in the
         * constructor.
         */
        List<ProcessingElement> pes = this.getAllInstances();
        for (ProcessingElement pe : pes) {
            ((ClusterPE) pe).distanceStream = distanceStream;
        }
    }

    /**
     * @param clusterId
     *            id of the cluster represented by this instance
     */
    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    /**
     * @param centroid
     *            centroid vector; kept by reference and overwritten in place
     *            when the centroid is re-estimated
     */
    public void setCentroid(float[] centroid) {
        this.centroid = centroid;
    }

    /**
     * @return the observation count held by this object (per-cluster count on
     *         an instance, global count on the prototype)
     */
    public long getObsCount() {
        return obsCount;
    }

    /**
     * Folds a labeled event into the global statistics kept on the prototype
     * and, when a full pass has been seen, re-estimates all centroids and
     * resets the global statistics.
     *
     * <p>
     * The lock is taken on the prototype because that is the object whose
     * fields are shared by all instances; locking the calling instance would
     * let two different instances update the shared counters at the same time.
     */
    private void updateTotalStats(ObsEvent event) {

        ClusterPE clusterPEPrototype = (ClusterPE) pePrototype;

        synchronized (clusterPEPrototype) {

            /* Update global stats in the prototype. */
            clusterPEPrototype.obsCount++;
            clusterPEPrototype.totalDistance += event.getDistance();
            clusterPEPrototype.confusionMatrix[event.getClassId()][event
                    .getHypId()] += 1;

            /* Avoid building the message unless trace is enabled. */
            if (logger.isTraceEnabled()) {
                logger.trace("Index: " + event.getIndex() + ", Label: "
                        + event.getClassId() + ", Hyp: " + event.getHypId()
                        + ", Total Count: " + clusterPEPrototype.obsCount
                        + ", Total Dist: " + clusterPEPrototype.totalDistance);
            }

            /* Periodic progress report. */
            if (clusterPEPrototype.obsCount % 10000 == 0) {
                logger.info("Processed {} events", clusterPEPrototype.obsCount);
                logger.info("Average distance is {}.",
                        clusterPEPrototype.totalDistance
                                / clusterPEPrototype.obsCount);
            }

            if (clusterPEPrototype.obsCount == numVectors) {

                /* Done processing training set. */
                logger.info("Final Count: {}.", clusterPEPrototype.obsCount);
                logger.info("Final Average Distance: {}.",
                        clusterPEPrototype.totalDistance
                                / clusterPEPrototype.obsCount);

                for (int i = 0; i < numClusters; i++) {
                    for (int j = 0; j < numClusters; j++) {
                        Object[] paramArray = { i, j,
                                clusterPEPrototype.confusionMatrix[i][j] };
                        logger.info(
                                "Final Count of class {} classified as {}: {}.",
                                paramArray);
                    }
                }

                /* Update centroids of every cluster instance. */
                for (Map.Entry<String, ProcessingElement> entry : peInstances
                        .entrySet()) {
                    ClusterPE pe = (ClusterPE) entry.getValue();
                    pe.updateCentroid();
                }

                /* Reset global stats. */
                clusterPEPrototype.obsCount = 0;
                clusterPEPrototype.totalDistance = 0f;
                clusterPEPrototype.confusionMatrix = new int[numClusters][numClusters];
            }
        }
    }

    /*
     * Compute Euclidean distance between an observed vector and the centroid.
     * Only the first vectorSize components are used.
     */
    private float distance(float[] obs) {

        float sumSq = 0f;
        for (int i = 0; i < vectorSize; i++) {
            float diff = centroid[i] - obs[i];
            sumSq += diff * diff;
        }
        return (float) Math.sqrt(sumSq);
    }

    /*
     * Replace the centroid with the mean of the observations accumulated since
     * the last update, then clear the accumulators.
     *
     * A cluster that received no observations keeps its previous centroid:
     * dividing the all-zero sum by a zero count would produce NaN components,
     * and every later distance computed against that centroid would be NaN.
     */
    private void updateCentroid() {

        if (obsCount > 0) {
            for (int i = 0; i < vectorSize; i++) {
                centroid[i] = obsSum[i] / obsCount;
            }
        } else {
            logger.warn(
                    "Cluster {} received no observations, keeping previous centroid.",
                    clusterId);
        }

        for (int i = 0; i < vectorSize; i++) {
            obsSum[i] = 0f;
        }
        obsCount = 0;
    }

    /**
     * Handles one {@link ObsEvent}.
     *
     * <p>
     * A raw event (distance below zero) is answered with a new event carrying
     * the distance to this cluster's centroid and this cluster's id as the
     * hypothesis. A labeled event updates the per-cluster and global
     * statistics.
     *
     * @see io.s4.ProcessingElement#processInputEvent(io.s4.Event)
     */
    @Override
    protected void processInputEvent(Event event) {

        ObsEvent inEvent = (ObsEvent) event;
        float[] obs = inEvent.getObsVector();

        /* The raw ObsEvent should have the distance set to less than 0.0. */
        if (inEvent.getDistance() < 0f) {

            /* Process raw event. */
            float dist = distance(obs);
            ObsEvent outEvent = new ObsEvent(inEvent.getIndex(), obs, dist,
                    inEvent.getClassId(), clusterId);
            logger.trace("IN: {}", inEvent);
            logger.trace("OUT: {}", outEvent);
            distanceStream.put(outEvent);

        } else {

            /* This is a labeled event. Update sufficient statistics. */
            logger.trace("LABELED IN: {}", inEvent);

            /* Update obs count for this class. */
            obsCount++;

            /* Log info. */
            if (obsCount % 1000 == 0) {
                logger.info("Labeled {} events with class id {}", obsCount,
                        clusterId);
            }

            /*
             * Accumulate the observation BEFORE updating the global stats.
             * updateTotalStats() may re-estimate the centroids and clear
             * obsSum/obsCount; if the sum were updated afterwards, the last
             * observation of a pass would be missing from the new centroid and
             * would leak into the next pass.
             */
            for (int i = 0; i < vectorSize; i++) {
                obsSum[i] += obs[i];
            }

            /* Update total obs count and distance. */
            updateTotalStats(inEvent);
        }
    }

    /** Does nothing; output events are emitted from processInputEvent. */
    @Override
    public void sendEvent() {
        // Nothing to send here.
    }

    /** Allocates the per-instance accumulator for the observation sum. */
    @Override
    protected void initPEInstance() {

        /* Create an array for each PE instance. */
        this.obsSum = new float[vectorSize];
    }

    /** Does nothing; this PE performs no per-key cleanup. */
    @Override
    protected void removeInstanceForKey(String id) {
        // Nothing to clean up.
    }
}