Format java source files according to the eclipse-format.xml standard
- spaces, no tabs
- indent size = 2 spaces
- line wrap at 120
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/ContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/ContentEvent.java
index a3ef92a..d9bb944 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/ContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/ContentEvent.java
@@ -24,20 +24,21 @@
* The Interface ContentEvent.
*/
public interface ContentEvent extends java.io.Serializable {
-
- /**
- * Gets the content event key.
- *
- * @return the key
- */
- public String getKey();
-
- /**
- * Sets the content event key.
- *
- * @param key string
- */
- public void setKey(String key);
-
- public boolean isLastEvent();
+
+ /**
+ * Gets the content event key.
+ *
+ * @return the key
+ */
+ public String getKey();
+
+ /**
+ * Sets the content event key.
+ *
+ * @param key
+ * string
+ */
+ public void setKey(String key);
+
+ public boolean isLastEvent();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/DoubleVector.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/DoubleVector.java
index 39362b5..6be9452 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/DoubleVector.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/DoubleVector.java
@@ -26,94 +26,94 @@
public class DoubleVector implements java.io.Serializable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 8243012708860261398L;
+ private static final long serialVersionUID = 8243012708860261398L;
- private double[] doubleArray;
+ private double[] doubleArray;
- public DoubleVector() {
- this.doubleArray = new double[0];
+ public DoubleVector() {
+ this.doubleArray = new double[0];
+ }
+
+ public DoubleVector(double[] toCopy) {
+ this.doubleArray = new double[toCopy.length];
+ System.arraycopy(toCopy, 0, this.doubleArray, 0, toCopy.length);
+ }
+
+ public DoubleVector(DoubleVector toCopy) {
+ this(toCopy.getArrayRef());
+ }
+
+ public double[] getArrayRef() {
+ return this.doubleArray;
+ }
+
+ public double[] getArrayCopy() {
+ return Doubles.concat(this.doubleArray);
+ }
+
+ public int numNonZeroEntries() {
+ int count = 0;
+ for (double element : this.doubleArray) {
+ if (Double.compare(element, 0.0) != 0) {
+ count++;
+ }
}
+ return count;
+ }
- public DoubleVector(double[] toCopy) {
- this.doubleArray = new double[toCopy.length];
- System.arraycopy(toCopy, 0, this.doubleArray, 0, toCopy.length);
+ public void setValue(int index, double value) {
+ if (index >= doubleArray.length) {
+ this.doubleArray = Doubles.ensureCapacity(this.doubleArray, index + 1, 0);
}
+ this.doubleArray[index] = value;
+ }
- public DoubleVector(DoubleVector toCopy) {
- this(toCopy.getArrayRef());
+ public void addToValue(int index, double value) {
+ if (index >= doubleArray.length) {
+ this.doubleArray = Doubles.ensureCapacity(this.doubleArray, index + 1, 0);
}
+ this.doubleArray[index] += value;
+ }
- public double[] getArrayRef() {
- return this.doubleArray;
+ public double sumOfValues() {
+ double sum = 0.0;
+ for (double element : this.doubleArray) {
+ sum += element;
}
+ return sum;
+ }
- public double[] getArrayCopy() {
- return Doubles.concat(this.doubleArray);
- }
+ public void getSingleLineDescription(StringBuilder out) {
+ out.append("{");
+ out.append(Doubles.join("|", this.doubleArray));
+ out.append("}");
+ }
- public int numNonZeroEntries() {
- int count = 0;
- for (double element : this.doubleArray) {
- if (Double.compare(element, 0.0) != 0) {
- count++;
- }
- }
- return count;
- }
+ @Override
+ public String toString() {
+ return "DoubleVector [doubleArray=" + Arrays.toString(doubleArray) + "]";
+ }
- public void setValue(int index, double value) {
- if (index >= doubleArray.length) {
- this.doubleArray = Doubles.ensureCapacity(this.doubleArray, index + 1, 0);
- }
- this.doubleArray[index] = value;
- }
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + Arrays.hashCode(doubleArray);
+ return result;
+ }
- public void addToValue(int index, double value) {
- if (index >= doubleArray.length) {
- this.doubleArray = Doubles.ensureCapacity(this.doubleArray, index + 1, 0);
- }
- this.doubleArray[index] += value;
- }
-
- public double sumOfValues() {
- double sum = 0.0;
- for (double element : this.doubleArray) {
- sum += element;
- }
- return sum;
- }
-
- public void getSingleLineDescription(StringBuilder out) {
- out.append("{");
- out.append(Doubles.join("|", this.doubleArray));
- out.append("}");
- }
-
- @Override
- public String toString() {
- return "DoubleVector [doubleArray=" + Arrays.toString(doubleArray) + "]";
- }
-
- @Override
- public int hashCode() {
- final int prime = 31;
- int result = 1;
- result = prime * result + Arrays.hashCode(doubleArray);
- return result;
- }
-
- @Override
- public boolean equals(Object obj) {
- if (this == obj)
- return true;
- if (obj == null)
- return false;
- if (!(obj instanceof DoubleVector))
- return false;
- DoubleVector other = (DoubleVector) obj;
- return Arrays.equals(doubleArray, other.doubleArray);
- }
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (!(obj instanceof DoubleVector))
+ return false;
+ DoubleVector other = (DoubleVector) obj;
+ return Arrays.equals(doubleArray, other.doubleArray);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/EntranceProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/EntranceProcessor.java
index e1bdc14..e9eac30 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/EntranceProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/EntranceProcessor.java
@@ -25,35 +25,40 @@
import com.github.javacliparser.Configurable;
/**
- * An EntranceProcessor is a specific kind of processor dedicated to providing events to inject in the topology. It can be connected to a single output stream.
+ * An EntranceProcessor is a specific kind of processor dedicated to providing
+ * events to inject in the topology. It can be connected to a single output
+ * stream.
*/
public interface EntranceProcessor extends Serializable, Configurable, Processor {
- /**
- * Initializes the Processor. This method is called once after the topology is set up and before any call to the {@link nextTuple} method.
- *
- * @param the
- * identifier of the processor.
- */
- public void onCreate(int id);
-
- /**
- * Checks whether the source stream is finished/exhausted.
- */
- public boolean isFinished();
-
- /**
- * Checks whether a new event is ready to be processed.
- *
- * @return true if the EntranceProcessor is ready to provide the next event, false otherwise.
- */
- public boolean hasNext();
+ /**
+ * Initializes the Processor. This method is called once after the topology is
+ * set up and before any call to the {@link #nextEvent()} method.
+ *
+ * @param id
+ *          the identifier of the processor.
+ */
+ public void onCreate(int id);
- /**
- * Provides the next tuple to be processed by the topology. This method is the entry point for external events into the topology.
- *
- * @return the next event to be processed.
- */
- public ContentEvent nextEvent();
+ /**
+ * Checks whether the source stream is finished/exhausted.
+ */
+ public boolean isFinished();
+
+ /**
+ * Checks whether a new event is ready to be processed.
+ *
+ * @return true if the EntranceProcessor is ready to provide the next event,
+ * false otherwise.
+ */
+ public boolean hasNext();
+
+ /**
+ * Provides the next tuple to be processed by the topology. This method is the
+ * entry point for external events into the topology.
+ *
+ * @return the next event to be processed.
+ */
+ public ContentEvent nextEvent();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Globals.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Globals.java
index 8e04016..e3435c1 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Globals.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Globals.java
@@ -34,26 +34,26 @@
*/
public class Globals {
- public static final String workbenchTitle = "SAMOA: Scalable Advanced Massive Online Analysis Platform ";
+ public static final String workbenchTitle = "SAMOA: Scalable Advanced Massive Online Analysis Platform ";
- public static final String versionString = "0.0.1";
+ public static final String versionString = "0.0.1";
- public static final String copyrightNotice = "Copyright Yahoo! Inc 2013";
+ public static final String copyrightNotice = "Copyright Yahoo! Inc 2013";
- public static final String webAddress = "http://github.com/yahoo/samoa";
+ public static final String webAddress = "http://github.com/yahoo/samoa";
- public static String getWorkbenchInfoString() {
- StringBuilder result = new StringBuilder();
- result.append(workbenchTitle);
- StringUtils.appendNewline(result);
- result.append("Version: ");
- result.append(versionString);
- StringUtils.appendNewline(result);
- result.append("Copyright: ");
- result.append(copyrightNotice);
- StringUtils.appendNewline(result);
- result.append("Web: ");
- result.append(webAddress);
- return result.toString();
- }
+ public static String getWorkbenchInfoString() {
+ StringBuilder result = new StringBuilder();
+ result.append(workbenchTitle);
+ StringUtils.appendNewline(result);
+ result.append("Version: ");
+ result.append(versionString);
+ StringUtils.appendNewline(result);
+ result.append("Copyright: ");
+ result.append(copyrightNotice);
+ StringUtils.appendNewline(result);
+ result.append("Web: ");
+ result.append(webAddress);
+ return result.toString();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Processor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Processor.java
index 2033fae..b02d33c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Processor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/Processor.java
@@ -29,33 +29,36 @@
*/
public interface Processor extends Serializable, Configurable {
- /**
- * Entry point for the {@link Processor} code. This method is called once for every event received.
- *
- * @param event
- * the event to be processed.
- * @return true if successful, false otherwise.
- */
- boolean process(ContentEvent event);
+ /**
+ * Entry point for the {@link Processor} code. This method is called once for
+ * every event received.
+ *
+ * @param event
+ * the event to be processed.
+ * @return true if successful, false otherwise.
+ */
+ boolean process(ContentEvent event);
- /**
- * Initializes the Processor.
- * This method is called once after the topology is set up and before any call to the {@link process} method.
- *
- * @param id
- * the identifier of the processor.
- */
- void onCreate(int id);
+ /**
+ * Initializes the Processor. This method is called once after the topology is
+ * set up and before any call to the {@link process} method.
+ *
+ * @param id
+ * the identifier of the processor.
+ */
+ void onCreate(int id);
- /**
- * Creates a copy of a processor.
- * This method is used to instantiate multiple instances of the same {@link Processsor}.
- *
- * @param processor
- * the processor to be copied.
- *
- * @return a new instance of the {@link Processor}.
- * */
- Processor newProcessor(Processor processor); // FIXME there should be no need for the processor as a parameter
- // TODO can we substitute this with Cloneable?
+ /**
+ * Creates a copy of a processor. This method is used to instantiate multiple
+ * instances of the same {@link Processor}.
+ *
+ * @param processor
+ * the processor to be copied.
+ *
+ * @return a new instance of the {@link Processor}.
+ * */
+ Processor newProcessor(Processor processor); // FIXME there should be no need
+ // for the processor as a
+ // parameter
+ // TODO can we substitute this with Cloneable?
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/SerializableInstance.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/SerializableInstance.java
index 715c656..a4d5b24 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/core/SerializableInstance.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/core/SerializableInstance.java
@@ -31,37 +31,39 @@
//import weka.core.Instance;
/**
- * The Class SerializableInstance.
- * This class is needed for serialization of kryo
+ * The Class SerializableInstance. This class is needed for Kryo
+ * serialization.
*/
public class SerializableInstance extends DenseInstance {
- /** The Constant serialVersionUID. */
- private static final long serialVersionUID = -3659459626274566468L;
+ /** The Constant serialVersionUID. */
+ private static final long serialVersionUID = -3659459626274566468L;
- /**
- * Instantiates a new serializable instance.
- */
- public SerializableInstance() {
- super(0);
- }
+ /**
+ * Instantiates a new serializable instance.
+ */
+ public SerializableInstance() {
+ super(0);
+ }
- /**
- * Instantiates a new serializable instance.
- *
- * @param arg0 the arg0
- */
- public SerializableInstance(int arg0) {
- super(arg0);
- }
+ /**
+ * Instantiates a new serializable instance.
+ *
+ * @param arg0
+ * the arg0
+ */
+ public SerializableInstance(int arg0) {
+ super(arg0);
+ }
- /**
- * Instantiates a new serializable instance.
- *
- * @param inst the inst
- */
- public SerializableInstance(Instance inst) {
- super(inst);
- }
+ /**
+ * Instantiates a new serializable instance.
+ *
+ * @param inst
+ * the inst
+ */
+ public SerializableInstance(Instance inst) {
+ super(inst);
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicClassificationPerformanceEvaluator.java
index 89a89c0..bc1d447 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicClassificationPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicClassificationPerformanceEvaluator.java
@@ -33,125 +33,125 @@
* @version $Revision: 7 $
*/
public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject implements
- ClassificationPerformanceEvaluator {
+ ClassificationPerformanceEvaluator {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected double weightObserved;
+ protected double weightObserved;
- protected double weightCorrect;
+ protected double weightCorrect;
- protected double[] columnKappa;
+ protected double[] columnKappa;
- protected double[] rowKappa;
+ protected double[] rowKappa;
- protected int numClasses;
+ protected int numClasses;
- private double weightCorrectNoChangeClassifier;
+ private double weightCorrectNoChangeClassifier;
- private int lastSeenClass;
+ private int lastSeenClass;
- @Override
- public void reset() {
- reset(this.numClasses);
+ @Override
+ public void reset() {
+ reset(this.numClasses);
+ }
+
+ public void reset(int numClasses) {
+ this.numClasses = numClasses;
+ this.rowKappa = new double[numClasses];
+ this.columnKappa = new double[numClasses];
+ for (int i = 0; i < this.numClasses; i++) {
+ this.rowKappa[i] = 0.0;
+ this.columnKappa[i] = 0.0;
}
+ this.weightObserved = 0.0;
+ this.weightCorrect = 0.0;
+ this.weightCorrectNoChangeClassifier = 0.0;
+ this.lastSeenClass = 0;
+ }
- public void reset(int numClasses) {
- this.numClasses = numClasses;
- this.rowKappa = new double[numClasses];
- this.columnKappa = new double[numClasses];
- for (int i = 0; i < this.numClasses; i++) {
- this.rowKappa[i] = 0.0;
- this.columnKappa[i] = 0.0;
- }
- this.weightObserved = 0.0;
- this.weightCorrect = 0.0;
- this.weightCorrectNoChangeClassifier = 0.0;
- this.lastSeenClass = 0;
+ @Override
+ public void addResult(Instance inst, double[] classVotes) {
+ double weight = inst.weight();
+ int trueClass = (int) inst.classValue();
+ if (weight > 0.0) {
+ if (this.weightObserved == 0) {
+ reset(inst.numClasses());
+ }
+ this.weightObserved += weight;
+ int predictedClass = Utils.maxIndex(classVotes);
+ if (predictedClass == trueClass) {
+ this.weightCorrect += weight;
+ }
+ if (rowKappa.length > 0) {
+ this.rowKappa[predictedClass] += weight;
+ }
+ if (columnKappa.length > 0) {
+ this.columnKappa[trueClass] += weight;
+ }
}
-
- @Override
- public void addResult(Instance inst, double[] classVotes) {
- double weight = inst.weight();
- int trueClass = (int) inst.classValue();
- if (weight > 0.0) {
- if (this.weightObserved == 0) {
- reset(inst.numClasses());
- }
- this.weightObserved += weight;
- int predictedClass = Utils.maxIndex(classVotes);
- if (predictedClass == trueClass) {
- this.weightCorrect += weight;
- }
- if(rowKappa.length > 0){
- this.rowKappa[predictedClass] += weight;
- }
- if (columnKappa.length > 0) {
- this.columnKappa[trueClass] += weight;
- }
- }
- if (this.lastSeenClass == trueClass) {
- this.weightCorrectNoChangeClassifier += weight;
- }
- this.lastSeenClass = trueClass;
+ if (this.lastSeenClass == trueClass) {
+ this.weightCorrectNoChangeClassifier += weight;
}
+ this.lastSeenClass = trueClass;
+ }
- @Override
- public Measurement[] getPerformanceMeasurements() {
- return new Measurement[]{
- new Measurement("classified instances",
+ @Override
+ public Measurement[] getPerformanceMeasurements() {
+ return new Measurement[] {
+ new Measurement("classified instances",
getTotalWeightObserved()),
- new Measurement("classifications correct (percent)",
+ new Measurement("classifications correct (percent)",
getFractionCorrectlyClassified() * 100.0),
- new Measurement("Kappa Statistic (percent)",
+ new Measurement("Kappa Statistic (percent)",
getKappaStatistic() * 100.0),
- new Measurement("Kappa Temporal Statistic (percent)",
+ new Measurement("Kappa Temporal Statistic (percent)",
getKappaTemporalStatistic() * 100.0)
- };
+ };
+ }
+
+ public double getTotalWeightObserved() {
+ return this.weightObserved;
+ }
+
+ public double getFractionCorrectlyClassified() {
+ return this.weightObserved > 0.0 ? this.weightCorrect
+ / this.weightObserved : 0.0;
+ }
+
+ public double getFractionIncorrectlyClassified() {
+ return 1.0 - getFractionCorrectlyClassified();
+ }
+
+ public double getKappaStatistic() {
+ if (this.weightObserved > 0.0) {
+ double p0 = getFractionCorrectlyClassified();
+ double pc = 0.0;
+ for (int i = 0; i < this.numClasses; i++) {
+ pc += (this.rowKappa[i] / this.weightObserved)
+ * (this.columnKappa[i] / this.weightObserved);
+ }
+ return (p0 - pc) / (1.0 - pc);
+ } else {
+ return 0;
}
+ }
- public double getTotalWeightObserved() {
- return this.weightObserved;
+ public double getKappaTemporalStatistic() {
+ if (this.weightObserved > 0.0) {
+ double p0 = this.weightCorrect / this.weightObserved;
+ double pc = this.weightCorrectNoChangeClassifier / this.weightObserved;
+
+ return (p0 - pc) / (1.0 - pc);
+ } else {
+ return 0;
}
+ }
- public double getFractionCorrectlyClassified() {
- return this.weightObserved > 0.0 ? this.weightCorrect
- / this.weightObserved : 0.0;
- }
-
- public double getFractionIncorrectlyClassified() {
- return 1.0 - getFractionCorrectlyClassified();
- }
-
- public double getKappaStatistic() {
- if (this.weightObserved > 0.0) {
- double p0 = getFractionCorrectlyClassified();
- double pc = 0.0;
- for (int i = 0; i < this.numClasses; i++) {
- pc += (this.rowKappa[i] / this.weightObserved)
- * (this.columnKappa[i] / this.weightObserved);
- }
- return (p0 - pc) / (1.0 - pc);
- } else {
- return 0;
- }
- }
-
- public double getKappaTemporalStatistic() {
- if (this.weightObserved > 0.0) {
- double p0 = this.weightCorrect / this.weightObserved;
- double pc = this.weightCorrectNoChangeClassifier / this.weightObserved;
-
- return (p0 - pc) / (1.0 - pc);
- } else {
- return 0;
- }
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
- sb, indent);
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
+ sb, indent);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicRegressionPerformanceEvaluator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicRegressionPerformanceEvaluator.java
index d98fe72..b5f318e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicRegressionPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/BasicRegressionPerformanceEvaluator.java
@@ -26,109 +26,109 @@
/**
* Regression evaluator that performs basic incremental evaluation.
- *
+ *
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
* @version $Revision: 7 $
*/
public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject
- implements RegressionPerformanceEvaluator {
+ implements RegressionPerformanceEvaluator {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected double weightObserved;
+ protected double weightObserved;
- protected double squareError;
+ protected double squareError;
- protected double averageError;
-
- protected double sumTarget;
-
- protected double squareTargetError;
-
- protected double averageTargetError;
+ protected double averageError;
- @Override
- public void reset() {
- this.weightObserved = 0.0;
- this.squareError = 0.0;
- this.averageError = 0.0;
- this.sumTarget = 0.0;
- this.averageTargetError = 0.0;
- this.squareTargetError = 0.0;
-
+ protected double sumTarget;
+
+ protected double squareTargetError;
+
+ protected double averageTargetError;
+
+ @Override
+ public void reset() {
+ this.weightObserved = 0.0;
+ this.squareError = 0.0;
+ this.averageError = 0.0;
+ this.sumTarget = 0.0;
+ this.averageTargetError = 0.0;
+ this.squareTargetError = 0.0;
+
+ }
+
+ @Override
+ public void addResult(Instance inst, double[] prediction) {
+ double weight = inst.weight();
+ double classValue = inst.classValue();
+ if (weight > 0.0) {
+ if (prediction.length > 0) {
+ double meanTarget = this.weightObserved != 0 ?
+ this.sumTarget / this.weightObserved : 0.0;
+ this.squareError += (classValue - prediction[0]) * (classValue - prediction[0]);
+ this.averageError += Math.abs(classValue - prediction[0]);
+ this.squareTargetError += (classValue - meanTarget) * (classValue - meanTarget);
+ this.averageTargetError += Math.abs(classValue - meanTarget);
+ this.sumTarget += classValue;
+ this.weightObserved += weight;
+ }
}
+ }
- @Override
- public void addResult(Instance inst, double[] prediction) {
- double weight = inst.weight();
- double classValue = inst.classValue();
- if (weight > 0.0) {
- if (prediction.length > 0) {
- double meanTarget = this.weightObserved != 0 ?
- this.sumTarget / this.weightObserved : 0.0;
- this.squareError += (classValue - prediction[0]) * (classValue - prediction[0]);
- this.averageError += Math.abs(classValue - prediction[0]);
- this.squareTargetError += (classValue - meanTarget) * (classValue - meanTarget);
- this.averageTargetError += Math.abs(classValue - meanTarget);
- this.sumTarget += classValue;
- this.weightObserved += weight;
- }
- }
- }
+ @Override
+ public Measurement[] getPerformanceMeasurements() {
+ return new Measurement[] {
+ new Measurement("classified instances",
+ getTotalWeightObserved()),
+ new Measurement("mean absolute error",
+ getMeanError()),
+ new Measurement("root mean squared error",
+ getSquareError()),
+ new Measurement("relative mean absolute error",
+ getRelativeMeanError()),
+ new Measurement("relative root mean squared error",
+ getRelativeSquareError())
+ };
+ }
- @Override
- public Measurement[] getPerformanceMeasurements() {
- return new Measurement[]{
- new Measurement("classified instances",
- getTotalWeightObserved()),
- new Measurement("mean absolute error",
- getMeanError()),
- new Measurement("root mean squared error",
- getSquareError()),
- new Measurement("relative mean absolute error",
- getRelativeMeanError()),
- new Measurement("relative root mean squared error",
- getRelativeSquareError())
- };
- }
+ public double getTotalWeightObserved() {
+ return this.weightObserved;
+ }
- public double getTotalWeightObserved() {
- return this.weightObserved;
- }
+ public double getMeanError() {
+ return this.weightObserved > 0.0 ? this.averageError
+ / this.weightObserved : 0.0;
+ }
- public double getMeanError() {
- return this.weightObserved > 0.0 ? this.averageError
- / this.weightObserved : 0.0;
- }
+ public double getSquareError() {
+ return Math.sqrt(this.weightObserved > 0.0 ? this.squareError
+ / this.weightObserved : 0.0);
+ }
- public double getSquareError() {
- return Math.sqrt(this.weightObserved > 0.0 ? this.squareError
- / this.weightObserved : 0.0);
- }
-
- public double getTargetMeanError() {
- return this.weightObserved > 0.0 ? this.averageTargetError
- / this.weightObserved : 0.0;
- }
+ public double getTargetMeanError() {
+ return this.weightObserved > 0.0 ? this.averageTargetError
+ / this.weightObserved : 0.0;
+ }
- public double getTargetSquareError() {
- return Math.sqrt(this.weightObserved > 0.0 ? this.squareTargetError
- / this.weightObserved : 0.0);
- }
+ public double getTargetSquareError() {
+ return Math.sqrt(this.weightObserved > 0.0 ? this.squareTargetError
+ / this.weightObserved : 0.0);
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
- sb, indent);
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
+ sb, indent);
+ }
- private double getRelativeMeanError() {
- return this.averageTargetError> 0 ?
- this.averageError/this.averageTargetError : 0.0;
- }
+ private double getRelativeMeanError() {
+ return this.averageTargetError > 0 ?
+ this.averageError / this.averageTargetError : 0.0;
+ }
- private double getRelativeSquareError() {
- return Math.sqrt(this.squareTargetError> 0 ?
- this.squareError/this.squareTargetError : 0.0);
- }
+ private double getRelativeSquareError() {
+ return Math.sqrt(this.squareTargetError > 0 ?
+ this.squareError / this.squareTargetError : 0.0);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluationContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluationContentEvent.java
index 27fee6a..d482145 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluationContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluationContentEvent.java
@@ -32,51 +32,54 @@
*/
final public class ClusteringEvaluationContentEvent implements ContentEvent {
- private static final long serialVersionUID = -7746983521296618922L;
- private Clustering gtClustering;
- private DataPoint dataPoint;
- private final boolean isLast;
- private String key = "0";
+ private static final long serialVersionUID = -7746983521296618922L;
+ private Clustering gtClustering;
+ private DataPoint dataPoint;
+ private final boolean isLast;
+ private String key = "0";
- public ClusteringEvaluationContentEvent() {
- this.isLast = false;
- }
+ public ClusteringEvaluationContentEvent() {
+ this.isLast = false;
+ }
- public ClusteringEvaluationContentEvent(boolean isLast) {
- this.isLast = isLast;
- }
+ public ClusteringEvaluationContentEvent(boolean isLast) {
+ this.isLast = isLast;
+ }
- /**
- * Instantiates a new gtClustering result event.
- *
- * @param clustering the gtClustering result
- * @param instance data point
- * @param isLast is the last result
- */
- public ClusteringEvaluationContentEvent(Clustering clustering, DataPoint instance, boolean isLast) {
- this.gtClustering = clustering;
- this.isLast = isLast;
- this.dataPoint = instance;
- }
+ /**
+ * Instantiates a new gtClustering result event.
+ *
+ * @param clustering
+ * the gtClustering result
+ * @param instance
+ * data point
+ * @param isLast
+ * is the last result
+ */
+ public ClusteringEvaluationContentEvent(Clustering clustering, DataPoint instance, boolean isLast) {
+ this.gtClustering = clustering;
+ this.isLast = isLast;
+ this.dataPoint = instance;
+ }
- public String getKey() {
- return key;
- }
+ public String getKey() {
+ return key;
+ }
- public void setKey(String key) {
- this.key = key;
- }
+ public void setKey(String key) {
+ this.key = key;
+ }
- public boolean isLastEvent() {
- return this.isLast;
- }
+ public boolean isLastEvent() {
+ return this.isLast;
+ }
- Clustering getGTClustering() {
- return this.gtClustering;
- }
-
- DataPoint getDataPoint() {
- return this.dataPoint;
- }
-
+ Clustering getGTClustering() {
+ return this.gtClustering;
+ }
+
+ DataPoint getDataPoint() {
+ return this.dataPoint;
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluatorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluatorProcessor.java
index 2525a04..d8e0943 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluatorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringEvaluatorProcessor.java
@@ -45,275 +45,277 @@
public class ClusteringEvaluatorProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -2778051819116753612L;
+ private static final long serialVersionUID = -2778051819116753612L;
- private static final Logger logger = LoggerFactory.getLogger(EvaluatorProcessor.class);
+ private static final Logger logger = LoggerFactory.getLogger(EvaluatorProcessor.class);
- private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances";
+ private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances";
- private final int samplingFrequency;
- private final int decayHorizon;
- private final File dumpFile;
- private transient PrintStream immediateResultStream = null;
- private transient boolean firstDump = true;
+ private final int samplingFrequency;
+ private final int decayHorizon;
+ private final File dumpFile;
+ private transient PrintStream immediateResultStream = null;
+ private transient boolean firstDump = true;
- private long totalCount = 0;
- private long experimentStart = 0;
+ private long totalCount = 0;
+ private long experimentStart = 0;
- private LearningCurve learningCurve;
+ private LearningCurve learningCurve;
- private MeasureCollection[] measures;
+ private MeasureCollection[] measures;
- private int id;
+ private int id;
- protected Clustering gtClustering;
+ protected Clustering gtClustering;
- protected ArrayList<DataPoint> points;
+ protected ArrayList<DataPoint> points;
- private ClusteringEvaluatorProcessor(Builder builder) {
- this.samplingFrequency = builder.samplingFrequency;
- this.dumpFile = builder.dumpFile;
- this.points = new ArrayList<>();
- this.decayHorizon = builder.decayHorizon;
+ private ClusteringEvaluatorProcessor(Builder builder) {
+ this.samplingFrequency = builder.samplingFrequency;
+ this.dumpFile = builder.dumpFile;
+ this.points = new ArrayList<>();
+ this.decayHorizon = builder.decayHorizon;
+ }
+
+ @Override
+ public boolean process(ContentEvent event) {
+ boolean ret = false;
+ if (event instanceof ClusteringResultContentEvent) {
+ ret = process((ClusteringResultContentEvent) event);
+ }
+ if (event instanceof ClusteringEvaluationContentEvent) {
+ ret = process((ClusteringEvaluationContentEvent) event);
+ }
+ return ret;
+ }
+
+ private boolean process(ClusteringResultContentEvent result) {
+ // evaluate
+ Clustering clustering = KMeans.gaussianMeans(gtClustering, result.getClustering());
+ for (MeasureCollection measure : measures) {
+ try {
+ measure.evaluateClusteringPerformance(clustering, gtClustering, points);
+ } catch (Exception ex) {
+ ex.printStackTrace();
+ }
}
- @Override
- public boolean process(ContentEvent event) {
- boolean ret = false;
- if (event instanceof ClusteringResultContentEvent) {
- ret = process((ClusteringResultContentEvent) event);
- }
- if (event instanceof ClusteringEvaluationContentEvent) {
- ret = process((ClusteringEvaluationContentEvent) event);
- }
- return ret;
+ this.addMeasurement();
+
+ if (result.isLastEvent()) {
+ this.concludeMeasurement();
+ return true;
}
- private boolean process(ClusteringResultContentEvent result) {
- // evaluate
- Clustering clustering = KMeans.gaussianMeans(gtClustering, result.getClustering());
- for (MeasureCollection measure : measures) {
- try {
- measure.evaluateClusteringPerformance(clustering, gtClustering, points);
- } catch (Exception ex) {
- ex.printStackTrace();
- }
- }
+ totalCount += 1;
- this.addMeasurement();
-
- if (result.isLastEvent()) {
- this.concludeMeasurement();
- return true;
- }
-
- totalCount += 1;
-
- if (totalCount == 1) {
- experimentStart = System.nanoTime();
- }
-
- return false;
+ if (totalCount == 1) {
+ experimentStart = System.nanoTime();
}
- private boolean process(ClusteringEvaluationContentEvent result) {
- boolean ret = false;
- if (result.getGTClustering() != null) {
- gtClustering = result.getGTClustering();
- ret = true;
+ return false;
+ }
+
+ private boolean process(ClusteringEvaluationContentEvent result) {
+ boolean ret = false;
+ if (result.getGTClustering() != null) {
+ gtClustering = result.getGTClustering();
+ ret = true;
+ }
+ if (result.getDataPoint() != null) {
+ points.add(result.getDataPoint());
+ if (points.size() > this.decayHorizon) {
+ points.remove(0);
+ }
+ ret = true;
+ }
+ return ret;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.id = id;
+ this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME);
+ // create the measure collection
+ measures = getMeasures(getMeasureSelection());
+
+ if (this.dumpFile != null) {
+ try {
+ if (dumpFile.exists()) {
+ this.immediateResultStream = new PrintStream(new FileOutputStream(dumpFile, true), true);
+ } else {
+ this.immediateResultStream = new PrintStream(new FileOutputStream(dumpFile), true);
}
- if (result.getDataPoint() != null) {
- points.add(result.getDataPoint());
- if (points.size() > this.decayHorizon) {
- points.remove(0);
- }
- ret = true;
- }
- return ret;
+
+ } catch (FileNotFoundException e) {
+ this.immediateResultStream = null;
+ logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
+
+ } catch (Exception e) {
+ this.immediateResultStream = null;
+ logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
+ }
}
- @Override
- public void onCreate(int id) {
- this.id = id;
- this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME);
- // create the measure collection
- measures = getMeasures(getMeasureSelection());
+ this.firstDump = true;
+ }
- if (this.dumpFile != null) {
- try {
- if (dumpFile.exists()) {
- this.immediateResultStream = new PrintStream(new FileOutputStream(dumpFile, true), true);
- } else {
- this.immediateResultStream = new PrintStream(new FileOutputStream(dumpFile), true);
- }
+ private static ArrayList<Class> getMeasureSelection() {
+ ArrayList<Class> mclasses = new ArrayList<>();
+ // mclasses.add(EntropyCollection.class);
+ // mclasses.add(F1.class);
+ // mclasses.add(General.class);
+ // *mclasses.add(CMM.class);
+ mclasses.add(SSQ.class);
+ // *mclasses.add(SilhouetteCoefficient.class);
+ mclasses.add(StatisticalCollection.class);
+ // mclasses.add(Separation.class);
- } catch (FileNotFoundException e) {
- this.immediateResultStream = null;
- logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
+ return mclasses;
+ }
- } catch (Exception e) {
- this.immediateResultStream = null;
- logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
- }
- }
+ private static MeasureCollection[] getMeasures(ArrayList<Class> measure_classes) {
+ MeasureCollection[] measures = new MeasureCollection[measure_classes.size()];
+ for (int i = 0; i < measure_classes.size(); i++) {
+ try {
+ MeasureCollection m = (MeasureCollection) measure_classes.get(i).newInstance();
+ measures[i] = m;
- this.firstDump = true;
+ } catch (Exception ex) {
+ java.util.logging.Logger.getLogger("Couldn't create Instance for " + measure_classes.get(i).getName());
+ ex.printStackTrace();
+ }
+ }
+ return measures;
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ ClusteringEvaluatorProcessor originalProcessor = (ClusteringEvaluatorProcessor) p;
+ ClusteringEvaluatorProcessor newProcessor = new ClusteringEvaluatorProcessor.Builder(originalProcessor).build();
+
+ if (originalProcessor.learningCurve != null) {
+ newProcessor.learningCurve = originalProcessor.learningCurve;
}
- private static ArrayList<Class> getMeasureSelection() {
- ArrayList<Class> mclasses = new ArrayList<>();
- // mclasses.add(EntropyCollection.class);
- // mclasses.add(F1.class);
- // mclasses.add(General.class);
- // *mclasses.add(CMM.class);
- mclasses.add(SSQ.class);
- // *mclasses.add(SilhouetteCoefficient.class);
- mclasses.add(StatisticalCollection.class);
- // mclasses.add(Separation.class);
+ return newProcessor;
+ }
- return mclasses;
+ @Override
+ public String toString() {
+ StringBuilder report = new StringBuilder();
+
+ report.append(EvaluatorProcessor.class.getCanonicalName());
+ report.append("id = ").append(this.id);
+ report.append('\n');
+
+ if (learningCurve.numEntries() > 0) {
+ report.append(learningCurve.toString());
+ report.append('\n');
+ }
+ return report.toString();
+ }
+
+ private void addMeasurement() {
+ // printMeasures();
+ List<Measurement> measurements = new ArrayList<>();
+ measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount * this.samplingFrequency));
+
+ addClusteringPerformanceMeasurements(measurements);
+ Measurement[] finalMeasurements = measurements.toArray(new Measurement[measurements.size()]);
+
+ LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements);
+ learningCurve.insertEntry(learningEvaluation);
+ logger.debug("evaluator id = {}", this.id);
+ // logger.info(learningEvaluation.toString());
+
+ if (immediateResultStream != null) {
+ if (firstDump) {
+ immediateResultStream.println(learningCurve.headerToString());
+ firstDump = false;
+ }
+
+ immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
+ immediateResultStream.flush();
+ }
+ }
+
+ private void addClusteringPerformanceMeasurements(List<Measurement> measurements) {
+ for (MeasureCollection measure : measures) {
+ for (int j = 0; j < measure.getNumMeasures(); j++) {
+ Measurement measurement = new Measurement(measure.getName(j), measure.getLastValue(j));
+ measurements.add(measurement);
+ }
+ }
+ }
+
+ private void concludeMeasurement() {
+ logger.info("last event is received!");
+ logger.info("total count: {}", this.totalCount);
+
+ String learningCurveSummary = this.toString();
+ logger.info(learningCurveSummary);
+
+ long experimentEnd = System.nanoTime();
+ long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS);
+ logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount);
+ // logger.info("average throughput rate: {} instances/seconds",
+ // (totalCount/totalExperimentTime));
+ }
+
+ private void printMeasures() {
+ StringBuilder sb = new StringBuilder();
+ for (MeasureCollection measure : measures) {
+
+ sb.append("Mean ").append(measure.getClass().getSimpleName()).append(":").append(measure.getNumMeasures())
+ .append("\n");
+ for (int j = 0; j < measure.getNumMeasures(); j++) {
+ sb.append("[").append(measure.getName(j)).append("=").append(measure.getLastValue(j)).append("] \n");
+
+ }
+ sb.append("\n");
}
- private static MeasureCollection[] getMeasures(ArrayList<Class> measure_classes) {
- MeasureCollection[] measures = new MeasureCollection[measure_classes.size()];
- for (int i = 0; i < measure_classes.size(); i++) {
- try {
- MeasureCollection m = (MeasureCollection) measure_classes.get(i).newInstance();
- measures[i] = m;
+ logger.debug("\n MEASURES: \n\n {}", sb.toString());
+ System.out.println(sb.toString());
+ }
- } catch (Exception ex) {
- java.util.logging.Logger.getLogger("Couldn't create Instance for " + measure_classes.get(i).getName());
- ex.printStackTrace();
- }
- }
- return measures;
+ public static class Builder {
+
+ private int samplingFrequency = 1000;
+ private File dumpFile = null;
+ private int decayHorizon = 1000;
+
+ public Builder(int samplingFrequency) {
+ this.samplingFrequency = samplingFrequency;
}
- @Override
- public Processor newProcessor(Processor p) {
- ClusteringEvaluatorProcessor originalProcessor = (ClusteringEvaluatorProcessor) p;
- ClusteringEvaluatorProcessor newProcessor = new ClusteringEvaluatorProcessor.Builder(originalProcessor).build();
-
- if (originalProcessor.learningCurve != null) {
- newProcessor.learningCurve = originalProcessor.learningCurve;
- }
-
- return newProcessor;
+ public Builder(ClusteringEvaluatorProcessor oldProcessor) {
+ this.samplingFrequency = oldProcessor.samplingFrequency;
+ this.dumpFile = oldProcessor.dumpFile;
+ this.decayHorizon = oldProcessor.decayHorizon;
}
- @Override
- public String toString() {
- StringBuilder report = new StringBuilder();
-
- report.append(EvaluatorProcessor.class.getCanonicalName());
- report.append("id = ").append(this.id);
- report.append('\n');
-
- if (learningCurve.numEntries() > 0) {
- report.append(learningCurve.toString());
- report.append('\n');
- }
- return report.toString();
+ public Builder samplingFrequency(int samplingFrequency) {
+ this.samplingFrequency = samplingFrequency;
+ return this;
}
- private void addMeasurement() {
- // printMeasures();
- List<Measurement> measurements = new ArrayList<>();
- measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount * this.samplingFrequency));
-
- addClusteringPerformanceMeasurements(measurements);
- Measurement[] finalMeasurements = measurements.toArray(new Measurement[measurements.size()]);
-
- LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements);
- learningCurve.insertEntry(learningEvaluation);
- logger.debug("evaluator id = {}", this.id);
- // logger.info(learningEvaluation.toString());
-
- if (immediateResultStream != null) {
- if (firstDump) {
- immediateResultStream.println(learningCurve.headerToString());
- firstDump = false;
- }
-
- immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
- immediateResultStream.flush();
- }
+ public Builder decayHorizon(int decayHorizon) {
+ this.decayHorizon = decayHorizon;
+ return this;
}
- private void addClusteringPerformanceMeasurements(List<Measurement> measurements) {
- for (MeasureCollection measure : measures) {
- for (int j = 0; j < measure.getNumMeasures(); j++) {
- Measurement measurement = new Measurement(measure.getName(j), measure.getLastValue(j));
- measurements.add(measurement);
- }
- }
+ public Builder dumpFile(File file) {
+ this.dumpFile = file;
+ return this;
}
- private void concludeMeasurement() {
- logger.info("last event is received!");
- logger.info("total count: {}", this.totalCount);
-
- String learningCurveSummary = this.toString();
- logger.info(learningCurveSummary);
-
- long experimentEnd = System.nanoTime();
- long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS);
- logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount);
- // logger.info("average throughput rate: {} instances/seconds", (totalCount/totalExperimentTime));
+ public ClusteringEvaluatorProcessor build() {
+ return new ClusteringEvaluatorProcessor(this);
}
-
- private void printMeasures() {
- StringBuilder sb = new StringBuilder();
- for (MeasureCollection measure : measures) {
-
- sb.append("Mean ").append(measure.getClass().getSimpleName()).append(":").append(measure.getNumMeasures()).append("\n");
- for (int j = 0; j < measure.getNumMeasures(); j++) {
- sb.append("[").append(measure.getName(j)).append("=").append(measure.getLastValue(j)).append("] \n");
-
- }
- sb.append("\n");
- }
-
- logger.debug("\n MEASURES: \n\n {}", sb.toString());
- System.out.println(sb.toString());
- }
-
- public static class Builder {
-
- private int samplingFrequency = 1000;
- private File dumpFile = null;
- private int decayHorizon = 1000;
-
- public Builder(int samplingFrequency) {
- this.samplingFrequency = samplingFrequency;
- }
-
- public Builder(ClusteringEvaluatorProcessor oldProcessor) {
- this.samplingFrequency = oldProcessor.samplingFrequency;
- this.dumpFile = oldProcessor.dumpFile;
- this.decayHorizon = oldProcessor.decayHorizon;
- }
-
- public Builder samplingFrequency(int samplingFrequency) {
- this.samplingFrequency = samplingFrequency;
- return this;
- }
-
- public Builder decayHorizon(int decayHorizon) {
- this.decayHorizon = decayHorizon;
- return this;
- }
-
- public Builder dumpFile(File file) {
- this.dumpFile = file;
- return this;
- }
-
- public ClusteringEvaluatorProcessor build() {
- return new ClusteringEvaluatorProcessor(this);
- }
- }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringResultContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringResultContentEvent.java
index 1a5610e..95349c7 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringResultContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/ClusteringResultContentEvent.java
@@ -31,43 +31,45 @@
*/
final public class ClusteringResultContentEvent implements ContentEvent {
- private static final long serialVersionUID = -7746983521296618922L;
- private Clustering clustering;
- private final boolean isLast;
- private String key = "0";
+ private static final long serialVersionUID = -7746983521296618922L;
+ private Clustering clustering;
+ private final boolean isLast;
+ private String key = "0";
- public ClusteringResultContentEvent() {
- this.isLast = false;
- }
+ public ClusteringResultContentEvent() {
+ this.isLast = false;
+ }
- public ClusteringResultContentEvent(boolean isLast) {
- this.isLast = isLast;
- }
+ public ClusteringResultContentEvent(boolean isLast) {
+ this.isLast = isLast;
+ }
- /**
- * Instantiates a new clustering result event.
- *
- * @param clustering the clustering result
- * @param isLast is the last result
- */
- public ClusteringResultContentEvent(Clustering clustering, boolean isLast) {
- this.clustering = clustering;
- this.isLast = isLast;
- }
+ /**
+ * Instantiates a new clustering result event.
+ *
+ * @param clustering
+ * the clustering result
+ * @param isLast
+ * is the last result
+ */
+ public ClusteringResultContentEvent(Clustering clustering, boolean isLast) {
+ this.clustering = clustering;
+ this.isLast = isLast;
+ }
- public String getKey() {
- return key;
- }
+ public String getKey() {
+ return key;
+ }
- public void setKey(String key) {
- this.key = key;
- }
+ public void setKey(String key) {
+ this.key = key;
+ }
- public boolean isLastEvent() {
- return this.isLast;
- }
+ public boolean isLastEvent() {
+ return this.isLast;
+ }
- public Clustering getClustering() {
- return this.clustering;
- }
+ public Clustering getClustering() {
+ return this.clustering;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/EvaluatorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/EvaluatorProcessor.java
index f110872..ed2207f 100755
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/EvaluatorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/EvaluatorProcessor.java
@@ -42,193 +42,192 @@
public class EvaluatorProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -2778051819116753612L;
+ private static final long serialVersionUID = -2778051819116753612L;
- private static final Logger logger =
- LoggerFactory.getLogger(EvaluatorProcessor.class);
-
- private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances";
-
- private final PerformanceEvaluator evaluator;
- private final int samplingFrequency;
- private final File dumpFile;
- private transient PrintStream immediateResultStream = null;
- private transient boolean firstDump = true;
-
-
- private long totalCount = 0;
- private long experimentStart = 0;
-
- private long sampleStart = 0;
-
- private LearningCurve learningCurve;
- private int id;
+ private static final Logger logger =
+ LoggerFactory.getLogger(EvaluatorProcessor.class);
- private EvaluatorProcessor(Builder builder){
- this.evaluator = builder.evaluator;
- this.samplingFrequency = builder.samplingFrequency;
- this.dumpFile = builder.dumpFile;
- }
-
- @Override
- public boolean process(ContentEvent event) {
-
- ResultContentEvent result = (ResultContentEvent) event;
-
- if((totalCount > 0) && (totalCount % samplingFrequency) == 0){
- long sampleEnd = System.nanoTime();
- long sampleDuration = TimeUnit.SECONDS.convert(sampleEnd - sampleStart, TimeUnit.NANOSECONDS);
- sampleStart = sampleEnd;
-
- logger.info("{} seconds for {} instances", sampleDuration, samplingFrequency);
- this.addMeasurement();
- }
-
- if(result.isLastEvent()){
- this.concludeMeasurement();
- return true;
- }
-
- evaluator.addResult(result.getInstance(), result.getClassVotes());
- totalCount += 1;
-
- if(totalCount == 1){
- sampleStart = System.nanoTime();
- experimentStart = sampleStart;
- }
-
- return false;
- }
+ private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances";
- @Override
- public void onCreate(int id) {
- this.id = id;
- this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME);
+ private final PerformanceEvaluator evaluator;
+ private final int samplingFrequency;
+ private final File dumpFile;
+ private transient PrintStream immediateResultStream = null;
+ private transient boolean firstDump = true;
- if (this.dumpFile != null) {
- try {
- if(dumpFile.exists()){
- this.immediateResultStream = new PrintStream(
- new FileOutputStream(dumpFile, true), true);
- }else{
- this.immediateResultStream = new PrintStream(
- new FileOutputStream(dumpFile), true);
- }
-
- } catch (FileNotFoundException e) {
- this.immediateResultStream = null;
- logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
-
- } catch (Exception e){
- this.immediateResultStream = null;
- logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
- }
- }
-
- this.firstDump = true;
- }
+ private long totalCount = 0;
+ private long experimentStart = 0;
- @Override
- public Processor newProcessor(Processor p) {
- EvaluatorProcessor originalProcessor = (EvaluatorProcessor) p;
- EvaluatorProcessor newProcessor = new EvaluatorProcessor.Builder(originalProcessor).build();
-
- if (originalProcessor.learningCurve != null){
- newProcessor.learningCurve = originalProcessor.learningCurve;
- }
-
- return newProcessor;
- }
-
- @Override
- public String toString() {
- StringBuilder report = new StringBuilder();
-
- report.append(EvaluatorProcessor.class.getCanonicalName());
- report.append("id = ").append(this.id);
- report.append('\n');
-
- if(learningCurve.numEntries() > 0){
- report.append(learningCurve.toString());
- report.append('\n');
- }
- return report.toString();
- }
-
- private void addMeasurement(){
- List<Measurement> measurements = new Vector<>();
- measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount));
+ private long sampleStart = 0;
- Collections.addAll(measurements, evaluator.getPerformanceMeasurements());
-
- Measurement[] finalMeasurements = measurements.toArray(new Measurement[measurements.size()]);
-
- LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements);
- learningCurve.insertEntry(learningEvaluation);
- logger.debug("evaluator id = {}", this.id);
- logger.info(learningEvaluation.toString());
-
- if(immediateResultStream != null){
- if(firstDump){
- immediateResultStream.println(learningCurve.headerToString());
- firstDump = false;
- }
-
- immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() -1));
- immediateResultStream.flush();
- }
- }
-
- private void concludeMeasurement(){
- logger.info("last event is received!");
- logger.info("total count: {}", this.totalCount);
-
- String learningCurveSummary = this.toString();
- logger.info(learningCurveSummary);
+ private LearningCurve learningCurve;
+ private int id;
-
- long experimentEnd = System.nanoTime();
- long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS);
- logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount);
+ private EvaluatorProcessor(Builder builder) {
+ this.evaluator = builder.evaluator;
+ this.samplingFrequency = builder.samplingFrequency;
+ this.dumpFile = builder.dumpFile;
+ }
- if (immediateResultStream!=null) {
- immediateResultStream.println("# COMPLETED");
- immediateResultStream.flush();
- }
- //logger.info("average throughput rate: {} instances/seconds", (totalCount/totalExperimentTime));
- }
-
- public static class Builder{
-
- private final PerformanceEvaluator evaluator;
- private int samplingFrequency = 100000;
- private File dumpFile = null;
-
- public Builder(PerformanceEvaluator evaluator){
- this.evaluator = evaluator;
- }
-
- public Builder(EvaluatorProcessor oldProcessor){
- this.evaluator = oldProcessor.evaluator;
- this.samplingFrequency = oldProcessor.samplingFrequency;
- this.dumpFile = oldProcessor.dumpFile;
- }
-
- public Builder samplingFrequency(int samplingFrequency){
- this.samplingFrequency = samplingFrequency;
- return this;
- }
-
- public Builder dumpFile(File file){
- this.dumpFile = file;
- return this;
- }
-
- public EvaluatorProcessor build(){
- return new EvaluatorProcessor(this);
- }
- }
+ @Override
+ public boolean process(ContentEvent event) {
+
+ ResultContentEvent result = (ResultContentEvent) event;
+
+ if ((totalCount > 0) && (totalCount % samplingFrequency) == 0) {
+ long sampleEnd = System.nanoTime();
+ long sampleDuration = TimeUnit.SECONDS.convert(sampleEnd - sampleStart, TimeUnit.NANOSECONDS);
+ sampleStart = sampleEnd;
+
+ logger.info("{} seconds for {} instances", sampleDuration, samplingFrequency);
+ this.addMeasurement();
+ }
+
+ if (result.isLastEvent()) {
+ this.concludeMeasurement();
+ return true;
+ }
+
+ evaluator.addResult(result.getInstance(), result.getClassVotes());
+ totalCount += 1;
+
+ if (totalCount == 1) {
+ sampleStart = System.nanoTime();
+ experimentStart = sampleStart;
+ }
+
+ return false;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.id = id;
+ this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME);
+
+ if (this.dumpFile != null) {
+ try {
+ if (dumpFile.exists()) {
+ this.immediateResultStream = new PrintStream(
+ new FileOutputStream(dumpFile, true), true);
+ } else {
+ this.immediateResultStream = new PrintStream(
+ new FileOutputStream(dumpFile), true);
+ }
+
+ } catch (FileNotFoundException e) {
+ this.immediateResultStream = null;
+ logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
+
+ } catch (Exception e) {
+ this.immediateResultStream = null;
+ logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
+ }
+ }
+
+ this.firstDump = true;
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ EvaluatorProcessor originalProcessor = (EvaluatorProcessor) p;
+ EvaluatorProcessor newProcessor = new EvaluatorProcessor.Builder(originalProcessor).build();
+
+ if (originalProcessor.learningCurve != null) {
+ newProcessor.learningCurve = originalProcessor.learningCurve;
+ }
+
+ return newProcessor;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder report = new StringBuilder();
+
+ report.append(EvaluatorProcessor.class.getCanonicalName());
+ report.append("id = ").append(this.id);
+ report.append('\n');
+
+ if (learningCurve.numEntries() > 0) {
+ report.append(learningCurve.toString());
+ report.append('\n');
+ }
+ return report.toString();
+ }
+
+ private void addMeasurement() {
+ List<Measurement> measurements = new Vector<>();
+ measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount));
+
+ Collections.addAll(measurements, evaluator.getPerformanceMeasurements());
+
+ Measurement[] finalMeasurements = measurements.toArray(new Measurement[measurements.size()]);
+
+ LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements);
+ learningCurve.insertEntry(learningEvaluation);
+ logger.debug("evaluator id = {}", this.id);
+ logger.info(learningEvaluation.toString());
+
+ if (immediateResultStream != null) {
+ if (firstDump) {
+ immediateResultStream.println(learningCurve.headerToString());
+ firstDump = false;
+ }
+
+ immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
+ immediateResultStream.flush();
+ }
+ }
+
+ private void concludeMeasurement() {
+ logger.info("last event is received!");
+ logger.info("total count: {}", this.totalCount);
+
+ String learningCurveSummary = this.toString();
+ logger.info(learningCurveSummary);
+
+ long experimentEnd = System.nanoTime();
+ long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS);
+ logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount);
+
+ if (immediateResultStream != null) {
+ immediateResultStream.println("# COMPLETED");
+ immediateResultStream.flush();
+ }
+ // logger.info("average throughput rate: {} instances/seconds",
+ // (totalCount/totalExperimentTime));
+ }
+
+ public static class Builder {
+
+ private final PerformanceEvaluator evaluator;
+ private int samplingFrequency = 100000;
+ private File dumpFile = null;
+
+ public Builder(PerformanceEvaluator evaluator) {
+ this.evaluator = evaluator;
+ }
+
+ public Builder(EvaluatorProcessor oldProcessor) {
+ this.evaluator = oldProcessor.evaluator;
+ this.samplingFrequency = oldProcessor.samplingFrequency;
+ this.dumpFile = oldProcessor.dumpFile;
+ }
+
+ public Builder samplingFrequency(int samplingFrequency) {
+ this.samplingFrequency = samplingFrequency;
+ return this;
+ }
+
+ public Builder dumpFile(File file) {
+ this.dumpFile = file;
+ return this;
+ }
+
+ public EvaluatorProcessor build() {
+ return new EvaluatorProcessor(this);
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/PerformanceEvaluator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/PerformanceEvaluator.java
index b88e87a..8f81392 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/PerformanceEvaluator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/PerformanceEvaluator.java
@@ -34,29 +34,29 @@
*/
public interface PerformanceEvaluator extends MOAObject {
- /**
- * Resets this evaluator. It must be similar to starting a new evaluator
- * from scratch.
- *
- */
- public void reset();
+ /**
+ * Resets this evaluator. It must be similar to starting a new evaluator from
+ * scratch.
+ *
+ */
+ public void reset();
- /**
- * Adds a learning result to this evaluator.
- *
- * @param inst
- * the instance to be classified
- * @param classVotes
- * an array containing the estimated membership probabilities of
- * the test instance in each class
- * @return an array of measurements monitored in this evaluator
- */
- public void addResult(Instance inst, double[] classVotes);
+ /**
+ * Adds a learning result to this evaluator.
+ *
+ * @param inst
+ * the instance to be classified
+ * @param classVotes
+ * an array containing the estimated membership probabilities of the
+ * test instance in each class
+ * @return an array of measurements monitored in this evaluator
+ */
+ public void addResult(Instance inst, double[] classVotes);
- /**
- * Gets the current measurements monitored by this evaluator.
- *
- * @return an array of measurements monitored by this evaluator
- */
- public Measurement[] getPerformanceMeasurements();
+ /**
+ * Gets the current measurements monitored by this evaluator.
+ *
+ * @return an array of measurements monitored by this evaluator
+ */
+ public Measurement[] getPerformanceMeasurements();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/WindowClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/WindowClassificationPerformanceEvaluator.java
index 8b1f394..c1758c9 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/WindowClassificationPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/WindowClassificationPerformanceEvaluator.java
@@ -29,192 +29,191 @@
/**
* Classification evaluator that updates evaluation results using a sliding
* window.
- *
+ *
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
* @version $Revision: 7 $
*/
public class WindowClassificationPerformanceEvaluator extends AbstractMOAObject implements
- ClassificationPerformanceEvaluator {
+ ClassificationPerformanceEvaluator {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public IntOption widthOption = new IntOption("width",
- 'w', "Size of Window", 1000);
+ public IntOption widthOption = new IntOption("width",
+ 'w', "Size of Window", 1000);
- protected double TotalweightObserved = 0;
+ protected double TotalweightObserved = 0;
- protected Estimator weightObserved;
+ protected Estimator weightObserved;
- protected Estimator weightCorrect;
+ protected Estimator weightCorrect;
- protected Estimator weightCorrectNoChangeClassifier;
+ protected Estimator weightCorrectNoChangeClassifier;
- protected double lastSeenClass;
+ protected double lastSeenClass;
- protected Estimator[] columnKappa;
+ protected Estimator[] columnKappa;
- protected Estimator[] rowKappa;
+ protected Estimator[] rowKappa;
- protected Estimator[] classAccuracy;
+ protected Estimator[] classAccuracy;
- protected int numClasses;
+ protected int numClasses;
- public class Estimator {
+ public class Estimator {
- protected double[] window;
+ protected double[] window;
- protected int posWindow;
+ protected int posWindow;
- protected int lenWindow;
+ protected int lenWindow;
- protected int SizeWindow;
+ protected int SizeWindow;
- protected double sum;
+ protected double sum;
- public Estimator(int sizeWindow) {
- window = new double[sizeWindow];
- SizeWindow = sizeWindow;
- posWindow = 0;
- lenWindow = 0;
- }
-
- public void add(double value) {
- sum -= window[posWindow];
- sum += value;
- window[posWindow] = value;
- posWindow++;
- if (posWindow == SizeWindow) {
- posWindow = 0;
- }
- if (lenWindow < SizeWindow) {
- lenWindow++;
- }
- }
-
- public double total() {
- return sum;
- }
-
- public double length() {
- return lenWindow;
- }
-
+ public Estimator(int sizeWindow) {
+ window = new double[sizeWindow];
+ SizeWindow = sizeWindow;
+ posWindow = 0;
+ lenWindow = 0;
}
- /* public void setWindowWidth(int w) {
- this.width = w;
- reset();
- }*/
- @Override
- public void reset() {
- reset(this.numClasses);
+ public void add(double value) {
+ sum -= window[posWindow];
+ sum += value;
+ window[posWindow] = value;
+ posWindow++;
+ if (posWindow == SizeWindow) {
+ posWindow = 0;
+ }
+ if (lenWindow < SizeWindow) {
+ lenWindow++;
+ }
}
- public void reset(int numClasses) {
- this.numClasses = numClasses;
- this.rowKappa = new Estimator[numClasses];
- this.columnKappa = new Estimator[numClasses];
- this.classAccuracy = new Estimator[numClasses];
- for (int i = 0; i < this.numClasses; i++) {
- this.rowKappa[i] = new Estimator(this.widthOption.getValue());
- this.columnKappa[i] = new Estimator(this.widthOption.getValue());
- this.classAccuracy[i] = new Estimator(this.widthOption.getValue());
- }
- this.weightCorrect = new Estimator(this.widthOption.getValue());
- this.weightCorrectNoChangeClassifier = new Estimator(this.widthOption.getValue());
- this.weightObserved = new Estimator(this.widthOption.getValue());
- this.TotalweightObserved = 0;
- this.lastSeenClass = 0;
+ public double total() {
+ return sum;
}
- @Override
- public void addResult(Instance inst, double[] classVotes) {
- double weight = inst.weight();
- int trueClass = (int) inst.classValue();
- if (weight > 0.0) {
- if (TotalweightObserved == 0) {
- reset(inst.numClasses());
- }
- this.TotalweightObserved += weight;
- this.weightObserved.add(weight);
- int predictedClass = Utils.maxIndex(classVotes);
- if (predictedClass == trueClass) {
- this.weightCorrect.add(weight);
- } else {
- this.weightCorrect.add(0);
- }
- //Add Kappa statistic information
- for (int i = 0; i < this.numClasses; i++) {
- this.rowKappa[i].add(i == predictedClass ? weight : 0);
- this.columnKappa[i].add(i == trueClass ? weight : 0);
- }
- if (this.lastSeenClass == trueClass) {
- this.weightCorrectNoChangeClassifier.add(weight);
- } else {
- this.weightCorrectNoChangeClassifier.add(0);
- }
- this.classAccuracy[trueClass].add(predictedClass == trueClass ? weight : 0.0);
- this.lastSeenClass = trueClass;
- }
+ public double length() {
+ return lenWindow;
}
- @Override
- public Measurement[] getPerformanceMeasurements() {
- return new Measurement[]{
- new Measurement("classified instances",
+ }
+
+ /*
+ * public void setWindowWidth(int w) { this.width = w; reset(); }
+ */
+ @Override
+ public void reset() {
+ reset(this.numClasses);
+ }
+
+ public void reset(int numClasses) {
+ this.numClasses = numClasses;
+ this.rowKappa = new Estimator[numClasses];
+ this.columnKappa = new Estimator[numClasses];
+ this.classAccuracy = new Estimator[numClasses];
+ for (int i = 0; i < this.numClasses; i++) {
+ this.rowKappa[i] = new Estimator(this.widthOption.getValue());
+ this.columnKappa[i] = new Estimator(this.widthOption.getValue());
+ this.classAccuracy[i] = new Estimator(this.widthOption.getValue());
+ }
+ this.weightCorrect = new Estimator(this.widthOption.getValue());
+ this.weightCorrectNoChangeClassifier = new Estimator(this.widthOption.getValue());
+ this.weightObserved = new Estimator(this.widthOption.getValue());
+ this.TotalweightObserved = 0;
+ this.lastSeenClass = 0;
+ }
+
+ @Override
+ public void addResult(Instance inst, double[] classVotes) {
+ double weight = inst.weight();
+ int trueClass = (int) inst.classValue();
+ if (weight > 0.0) {
+ if (TotalweightObserved == 0) {
+ reset(inst.numClasses());
+ }
+ this.TotalweightObserved += weight;
+ this.weightObserved.add(weight);
+ int predictedClass = Utils.maxIndex(classVotes);
+ if (predictedClass == trueClass) {
+ this.weightCorrect.add(weight);
+ } else {
+ this.weightCorrect.add(0);
+ }
+ // Add Kappa statistic information
+ for (int i = 0; i < this.numClasses; i++) {
+ this.rowKappa[i].add(i == predictedClass ? weight : 0);
+ this.columnKappa[i].add(i == trueClass ? weight : 0);
+ }
+ if (this.lastSeenClass == trueClass) {
+ this.weightCorrectNoChangeClassifier.add(weight);
+ } else {
+ this.weightCorrectNoChangeClassifier.add(0);
+ }
+ this.classAccuracy[trueClass].add(predictedClass == trueClass ? weight : 0.0);
+ this.lastSeenClass = trueClass;
+ }
+ }
+
+ @Override
+ public Measurement[] getPerformanceMeasurements() {
+ return new Measurement[] {
+ new Measurement("classified instances",
this.TotalweightObserved),
- new Measurement("classifications correct (percent)",
+ new Measurement("classifications correct (percent)",
getFractionCorrectlyClassified() * 100.0),
- new Measurement("Kappa Statistic (percent)",
+ new Measurement("Kappa Statistic (percent)",
getKappaStatistic() * 100.0),
- new Measurement("Kappa Temporal Statistic (percent)",
+ new Measurement("Kappa Temporal Statistic (percent)",
getKappaTemporalStatistic() * 100.0)
- };
+ };
+ }
+
+ public double getTotalWeightObserved() {
+ return this.weightObserved.total();
+ }
+
+ public double getFractionCorrectlyClassified() {
+ return this.weightObserved.total() > 0.0 ? this.weightCorrect.total()
+ / this.weightObserved.total() : 0.0;
+ }
+
+ public double getKappaStatistic() {
+ if (this.weightObserved.total() > 0.0) {
+ double p0 = this.weightCorrect.total() / this.weightObserved.total();
+ double pc = 0;
+ for (int i = 0; i < this.numClasses; i++) {
+ pc += (this.rowKappa[i].total() / this.weightObserved.total())
+ * (this.columnKappa[i].total() / this.weightObserved.total());
+ }
+ return (p0 - pc) / (1 - pc);
+ } else {
+ return 0;
}
+ }
- public double getTotalWeightObserved() {
- return this.weightObserved.total();
+ public double getKappaTemporalStatistic() {
+ if (this.weightObserved.total() > 0.0) {
+ double p0 = this.weightCorrect.total() / this.weightObserved.total();
+ double pc = this.weightCorrectNoChangeClassifier.total() / this.weightObserved.total();
+
+ return (p0 - pc) / (1 - pc);
+ } else {
+ return 0;
}
+ }
- public double getFractionCorrectlyClassified() {
- return this.weightObserved.total() > 0.0 ? this.weightCorrect.total()
- / this.weightObserved.total() : 0.0;
- }
+ public double getFractionIncorrectlyClassified() {
+ return 1.0 - getFractionCorrectlyClassified();
+ }
- public double getKappaStatistic() {
- if (this.weightObserved.total() > 0.0) {
- double p0 = this.weightCorrect.total() / this.weightObserved.total();
- double pc = 0;
- for (int i = 0; i < this.numClasses; i++) {
- pc += (this.rowKappa[i].total() / this.weightObserved.total())
- * (this.columnKappa[i].total() / this.weightObserved.total());
- }
- return (p0 - pc) / (1 - pc);
- } else {
- return 0;
- }
- }
-
- public double getKappaTemporalStatistic() {
- if (this.weightObserved.total() > 0.0) {
- double p0 = this.weightCorrect.total() / this.weightObserved.total();
- double pc = this.weightCorrectNoChangeClassifier.total() / this.weightObserved.total();
-
- return (p0 - pc) / (1 - pc);
- } else {
- return 0;
- }
- }
-
- public double getFractionIncorrectlyClassified() {
- return 1.0 - getFractionCorrectlyClassified();
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
- sb, indent);
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
+ sb, indent);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM.java
index 1a41f6b..568f7c5 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM.java
@@ -33,477 +33,491 @@
*
* CMM: Main class
*
- * Reference: Kremer et al., "An Effective Evaluation Measure for Clustering on Evolving Data Streams", KDD, 2011
+ * Reference: Kremer et al.,
+ * "An Effective Evaluation Measure for Clustering on Evolving Data Streams",
+ * KDD, 2011
*
- * @author Timm jansen
- * Data Management and Data Exploration Group, RWTH Aachen University
-*/
+ * @author Timm jansen Data Management and Data Exploration Group, RWTH Aachen
+ * University
+ */
-public class CMM extends MeasureCollection{
-
- private static final long serialVersionUID = 1L;
+public class CMM extends MeasureCollection {
- /**
- * found clustering
- */
- private Clustering clustering;
+ private static final long serialVersionUID = 1L;
- /**
- * the ground truth analysis
- */
- private CMM_GTAnalysis gtAnalysis;
+ /**
+ * found clustering
+ */
+ private Clustering clustering;
+
+ /**
+ * the ground truth analysis
+ */
+ private CMM_GTAnalysis gtAnalysis;
+
+ /**
+ * number of points within the horizon
+ */
+ private int numPoints;
+
+ /**
+ * number of clusters in the found clustering
+ */
+ private int numFClusters;
+
+ /**
+ * number of cluster in the adjusted groundtruth clustering that was
+ * calculated through the groundtruth analysis
+ */
+ private int numGT0Classes;
+
+ /**
+ * match found clusters to GT clusters
+ */
+ private int matchMap[];
+
+ /**
+ * pointInclusionProbFC[p][C] contains the probability of point p being
+ * included in cluster C
+ */
+ private double[][] pointInclusionProbFC;
+
+ /**
+ * threshold that defines when a point is being considered belonging to a
+ * cluster
+ */
+ private double pointInclusionProbThreshold = 0.5;
+
+ /**
+ * parameterize the error weight of missed points (default 1)
+ */
+ private double lamdaMissed = 1;
+
+ /**
+ * enable/disable debug mode
+ */
+ public boolean debug = false;
+
+ /**
+ * enable/disable class merge (main feature of ground truth analysis)
+ */
+ public boolean enableClassMerge = true;
+
+ /**
+ * enable/disable model error when enabled errors that are caused by the
+ * underlying cluster model will not be counted
+ */
+ public boolean enableModelError = true;
+
+ @Override
+ protected String[] getNames() {
+ String[] names = { "CMM", "CMM Basic", "CMM Missed", "CMM Misplaced", "CMM Noise",
+ "CA Seperability", "CA Noise", "CA Modell" };
+ return names;
+ }
+
+ @Override
+ protected boolean[] getDefaultEnabled() {
+ boolean[] defaults = { false, false, false, false, false, false, false, false };
+ return defaults;
+ }
+
+ @Override
+ public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points)
+ throws Exception {
+ this.clustering = clustering;
+
+ numPoints = points.size();
+ numFClusters = clustering.size();
+
+ gtAnalysis = new CMM_GTAnalysis(trueClustering, points, enableClassMerge);
+
+ numGT0Classes = gtAnalysis.getNumberOfGT0Classes();
+
+ addValue("CA Seperability", gtAnalysis.getClassSeparability());
+ addValue("CA Noise", gtAnalysis.getNoiseSeparability());
+ addValue("CA Modell", gtAnalysis.getModelQuality());
+
+ /* init the matching and point distances */
+ calculateMatching();
+
+ /* calculate the actual error */
+ calculateError();
+ }
+
+ /**
+ * calculates the CMM specific matching between found clusters and ground
+ * truth clusters
+ */
+ private void calculateMatching() {
/**
- * number of points within the horizon
+ * found cluster frequencies
*/
- private int numPoints;
-
- /**
- * number of clusters in the found clustering
- */
- private int numFClusters;
-
- /**
- * number of cluster in the adjusted groundtruth clustering that
- * was calculated through the groundtruth analysis
- */
- private int numGT0Classes;
+ int[][] mapFC = new int[numFClusters][numGT0Classes];
/**
- * match found clusters to GT clusters
+ * ground truth cluster frequencies
*/
- private int matchMap[];
-
- /**
- * pointInclusionProbFC[p][C] contains the probability of point p
- * being included in cluster C
- */
- private double[][] pointInclusionProbFC;
-
- /**
- * threshold that defines when a point is being considered belonging to a cluster
- */
- private double pointInclusionProbThreshold = 0.5;
-
- /**
- * parameterize the error weight of missed points (default 1)
- */
- private double lamdaMissed = 1;
+ int[][] mapGT = new int[numGT0Classes][numGT0Classes];
+ int[] sumsFC = new int[numFClusters];
-
- /**
- * enable/disable debug mode
- */
- public boolean debug = false;
+ // calculate fuzzy mapping from
+ pointInclusionProbFC = new double[numPoints][numFClusters];
+ for (int p = 0; p < numPoints; p++) {
+ CMMPoint cmdp = gtAnalysis.getPoint(p);
+ // found cluster frequencies
+ for (int fc = 0; fc < numFClusters; fc++) {
+ Cluster cl = clustering.get(fc);
+ pointInclusionProbFC[p][fc] = cl.getInclusionProbability(cmdp);
+ if (pointInclusionProbFC[p][fc] >= pointInclusionProbThreshold) {
+ // make sure we don't count points twice that are contained in two
+ // merged clusters
+ if (cmdp.isNoise())
+ continue;
+ mapFC[fc][cmdp.workclass()]++;
+ sumsFC[fc]++;
+ }
+ }
-
- /**
- * enable/disable class merge (main feature of ground truth analysis)
- */
- public boolean enableClassMerge = true;
-
- /**
- * enable/disable model error
- * when enabled errors that are caused by the underling cluster model will not be counted
- */
- public boolean enableModelError = true;
-
-
- @Override
- protected String[] getNames() {
- String[] names = {"CMM","CMM Basic","CMM Missed","CMM Misplaced","CMM Noise",
- "CA Seperability", "CA Noise", "CA Modell"};
- return names;
+ // ground truth cluster frequencies
+ if (!cmdp.isNoise()) {
+ for (int hc = 0; hc < numGT0Classes; hc++) {
+ if (hc == cmdp.workclass()) {
+ mapGT[hc][hc]++;
+ }
+ else {
+ if (gtAnalysis.getGT0Cluster(hc).getInclusionProbability(cmdp) >= 1) {
+ mapGT[hc][cmdp.workclass()]++;
+ }
+ }
+ }
+ }
}
- @Override
- protected boolean[] getDefaultEnabled() {
- boolean [] defaults = {false, false, false, false, false, false, false, false};
- return defaults;
+ // assign each found cluster to a hidden cluster
+ matchMap = new int[numFClusters];
+ for (int fc = 0; fc < numFClusters; fc++) {
+ int matchIndex = -1;
+ // check if we only have one entry anyway
+ for (int hc0 = 0; hc0 < numGT0Classes; hc0++) {
+ if (mapFC[fc][hc0] != 0) {
+ if (matchIndex == -1)
+ matchIndex = hc0;
+ else {
+ matchIndex = -1;
+ break;
+ }
+ }
+ }
+
+ // more than one entry, so look for the most similar frequency profile
+ int minDiff = Integer.MAX_VALUE;
+ if (sumsFC[fc] != 0 && matchIndex == -1) {
+ ArrayList<Integer> fitCandidates = new ArrayList<Integer>();
+ for (int hc0 = 0; hc0 < numGT0Classes; hc0++) {
+ int errDiff = 0;
+ for (int hc1 = 0; hc1 < numGT0Classes; hc1++) {
+ // fc profile doesn't fit into current hc profile
+ double freq_diff = mapFC[fc][hc1] - mapGT[hc0][hc1];
+ if (freq_diff > 0) {
+ errDiff += freq_diff;
+ }
+ }
+ if (errDiff == 0) {
+ fitCandidates.add(hc0);
+ }
+ if (errDiff < minDiff) {
+ minDiff = errDiff;
+ matchIndex = hc0;
+ }
+ if (debug) {
+ // System.out.println("FC"+fc+"("+Arrays.toString(mapFC[fc])+") - HC0_"+hc0+"("+Arrays.toString(mapGT[hc0])+"):"+errDiff);
+ }
+ }
+ // if we have a fitting profile overwrite the min error choice
+ // if we have multiple fit candidates, use majority vote of
+ // corresponding classes
+ if (fitCandidates.size() != 0) {
+ int bestGTfit = fitCandidates.get(0);
+ for (int i = 1; i < fitCandidates.size(); i++) {
+ int GTfit = fitCandidates.get(i);
+ if (mapFC[fc][GTfit] > mapFC[fc][bestGTfit])
+ bestGTfit = fitCandidates.get(i);
+ }
+ matchIndex = bestGTfit;
+ }
+ }
+
+ matchMap[fc] = matchIndex;
+ int realMatch = -1;
+ if (matchIndex == -1) {
+ if (debug)
+ System.out.println("No cluster match: needs to be implemented?");
+ }
+ else {
+ realMatch = gtAnalysis.getGT0Cluster(matchMap[fc]).getLabel();
+ }
+ clustering.get(fc).setMeasureValue("CMM Match", "C" + realMatch);
+ clustering.get(fc).setMeasureValue("CMM Workclass", "C" + matchMap[fc]);
}
-
- @Override
- public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) throws Exception{
- this.clustering = clustering;
-
- numPoints = points.size();
- numFClusters = clustering.size();
-
- gtAnalysis = new CMM_GTAnalysis(trueClustering, points, enableClassMerge);
-
- numGT0Classes = gtAnalysis.getNumberOfGT0Classes();
-
- addValue("CA Seperability",gtAnalysis.getClassSeparability());
- addValue("CA Noise",gtAnalysis.getNoiseSeparability());
- addValue("CA Modell",gtAnalysis.getModelQuality());
-
- /* init the matching and point distances */
- calculateMatching();
-
- /* calculate the actual error */
- calculateError();
+ // print matching table
+ if (debug) {
+ for (int i = 0; i < numFClusters; i++) {
+ System.out.print("C" + ((int) clustering.get(i).getId()) + " N:" + ((int) clustering.get(i).getWeight())
+ + " | ");
+ for (int j = 0; j < numGT0Classes; j++) {
+ System.out.print(mapFC[i][j] + " ");
+ }
+ System.out.print(" = " + sumsFC[i] + " | ");
+ String match = "-";
+ if (matchMap[i] != -1) {
+ match = Integer.toString(gtAnalysis.getGT0Cluster(matchMap[i]).getLabel());
+ }
+ System.out.println(" --> " + match + "(work:" + matchMap[i] + ")");
+ }
}
+ }
-
+ /**
+ * Calculate the actual error values
+ */
+ private void calculateError() {
+ int totalErrorCount = 0;
+ int totalRedundancy = 0;
+ int trueCoverage = 0;
+ int totalCoverage = 0;
+
+ int numNoise = 0;
+ double errorNoise = 0;
+ double errorNoiseMax = 0;
+
+ double errorMissed = 0;
+ double errorMissedMax = 0;
+
+ double errorMisplaced = 0;
+ double errorMisplacedMax = 0;
+
+ double totalError = 0.0;
+ double totalErrorMax = 0.0;
+
/**
- * calculates the CMM specific matching between found clusters and ground truth clusters
+ * mainly iterate over all points and find the right error value for the
+ * point. Within the same run calculate various other stuff like coverage
+ * etc...
*/
- private void calculateMatching(){
+ for (int p = 0; p < numPoints; p++) {
+ CMMPoint cmdp = gtAnalysis.getPoint(p);
+ double weight = cmdp.weight();
+ // noise counter
+ if (cmdp.isNoise()) {
+ numNoise++;
+ // this is always 1
+ errorNoiseMax += cmdp.connectivity * weight;
+ }
+ else {
+ errorMissedMax += cmdp.connectivity * weight;
+ errorMisplacedMax += cmdp.connectivity * weight;
+ }
+ // sum up maxError as the individual errors are the quality weighted
+ // between 0-1
+ totalErrorMax += cmdp.connectivity * weight;
- /**
- * found cluster frequencies
- */
- int[][] mapFC = new int[numFClusters][numGT0Classes];
+ double err = 0;
+ int coverage = 0;
- /**
- * ground truth cluster frequencies
- */
- int[][] mapGT = new int[numGT0Classes][numGT0Classes];
- int [] sumsFC = new int[numFClusters];
+ // check every FCluster
+ for (int c = 0; c < numFClusters; c++) {
+ // contained in cluster c?
+ if (pointInclusionProbFC[p][c] >= pointInclusionProbThreshold) {
+ coverage++;
- //calculate fuzzy mapping from
- pointInclusionProbFC = new double[numPoints][numFClusters];
- for (int p = 0; p < numPoints; p++) {
- CMMPoint cmdp = gtAnalysis.getPoint(p);
- //found cluster frequencies
- for (int fc = 0; fc < numFClusters; fc++) {
- Cluster cl = clustering.get(fc);
- pointInclusionProbFC[p][fc] = cl.getInclusionProbability(cmdp);
- if (pointInclusionProbFC[p][fc] >= pointInclusionProbThreshold) {
- //make sure we don't count points twice that are contained in two merged clusters
- if(cmdp.isNoise()) continue;
- mapFC[fc][cmdp.workclass()]++;
- sumsFC[fc]++;
- }
+ if (!cmdp.isNoise()) {
+ // PLACED CORRECTLY
+ if (matchMap[c] == cmdp.workclass()) {
}
-
- //ground truth cluster frequencies
- if(!cmdp.isNoise()){
- for(int hc = 0; hc < numGT0Classes;hc++){
- if(hc == cmdp.workclass()){
- mapGT[hc][hc]++;
- }
- else{
- if(gtAnalysis.getGT0Cluster(hc).getInclusionProbability(cmdp) >= 1){
- mapGT[hc][cmdp.workclass()]++;
- }
- }
- }
+ // MISPLACED
+ else {
+ double errvalue = misplacedError(cmdp, c);
+ if (errvalue > err)
+ err = errvalue;
}
+ }
+ else {
+ // NOISE
+ double errvalue = noiseError(cmdp, c);
+ if (errvalue > err)
+ err = errvalue;
+ }
}
-
- //assign each found cluster to a hidden cluster
- matchMap = new int[numFClusters];
- for (int fc = 0; fc < numFClusters; fc++) {
- int matchIndex = -1;
- //check if we only have one entry anyway
- for (int hc0 = 0; hc0 < numGT0Classes; hc0++) {
- if(mapFC[fc][hc0]!=0){
- if(matchIndex == -1)
- matchIndex = hc0;
- else{
- matchIndex = -1;
- break;
- }
- }
- }
-
- //more then one entry, so look for most similar frequency profile
- int minDiff = Integer.MAX_VALUE;
- if(sumsFC[fc]!=0 && matchIndex == -1){
- ArrayList<Integer> fitCandidates = new ArrayList<Integer>();
- for (int hc0 = 0; hc0 < numGT0Classes; hc0++) {
- int errDiff = 0;
- for (int hc1 = 0; hc1 < numGT0Classes; hc1++) {
- //fc profile doesn't fit into current hc profile
- double freq_diff = mapFC[fc][hc1] - mapGT[hc0][hc1];
- if(freq_diff > 0){
- errDiff+= freq_diff;
- }
- }
- if(errDiff == 0){
- fitCandidates.add(hc0);
- }
- if(errDiff < minDiff){
- minDiff = errDiff;
- matchIndex = hc0;
- }
- if(debug){
- //System.out.println("FC"+fc+"("+Arrays.toString(mapFC[fc])+") - HC0_"+hc0+"("+Arrays.toString(mapGT[hc0])+"):"+errDiff);
- }
- }
- //if we have a fitting profile overwrite the min error choice
- //if we have multiple fit candidates, use majority vote of corresponding classes
- if(fitCandidates.size()!=0){
- int bestGTfit = fitCandidates.get(0);
- for(int i = 1; i < fitCandidates.size(); i++){
- int GTfit = fitCandidates.get(i);
- if(mapFC[fc][GTfit] > mapFC[fc][bestGTfit])
- bestGTfit=fitCandidates.get(i);
- }
- matchIndex = bestGTfit;
- }
- }
-
- matchMap[fc] = matchIndex;
- int realMatch = -1;
- if(matchIndex==-1){
- if(debug)
- System.out.println("No cluster match: needs to be implemented?");
- }
- else{
- realMatch = gtAnalysis.getGT0Cluster(matchMap[fc]).getLabel();
- }
- clustering.get(fc).setMeasureValue("CMM Match", "C"+realMatch);
- clustering.get(fc).setMeasureValue("CMM Workclass", "C"+matchMap[fc]);
+ }
+ // not in any cluster
+ if (coverage == 0) {
+ // MISSED
+ if (!cmdp.isNoise()) {
+ err = missedError(cmdp, true);
+ errorMissed += weight * err;
}
-
- //print matching table
- if(debug){
- for (int i = 0; i < numFClusters; i++) {
- System.out.print("C"+((int)clustering.get(i).getId()) + " N:"+((int)clustering.get(i).getWeight())+" | ");
- for (int j = 0; j < numGT0Classes; j++) {
- System.out.print(mapFC[i][j] + " ");
- }
- System.out.print(" = "+sumsFC[i] + " | ");
- String match = "-";
- if (matchMap[i]!=-1) {
- match = Integer.toString(gtAnalysis.getGT0Cluster(matchMap[i]).getLabel());
- }
- System.out.println(" --> " + match + "(work:"+matchMap[i]+")");
- }
+ // NOISE
+ else {
}
- }
-
-
- /**
- * Calculate the actual error values
- */
- private void calculateError(){
- int totalErrorCount = 0;
- int totalRedundancy = 0;
- int trueCoverage = 0;
- int totalCoverage = 0;
-
- int numNoise = 0;
- double errorNoise = 0;
- double errorNoiseMax = 0;
-
- double errorMissed = 0;
- double errorMissedMax = 0;
-
- double errorMisplaced = 0;
- double errorMisplacedMax = 0;
-
- double totalError = 0.0;
- double totalErrorMax = 0.0;
-
- /** mainly iterate over all points and find the right error value for the point.
- * within the same run calculate various other stuff like coverage etc...
- */
- for (int p = 0; p < numPoints; p++) {
- CMMPoint cmdp = gtAnalysis.getPoint(p);
- double weight = cmdp.weight();
- //noise counter
- if(cmdp.isNoise()){
- numNoise++;
- //this is always 1
- errorNoiseMax+=cmdp.connectivity*weight;
- }
- else{
- errorMissedMax+=cmdp.connectivity*weight;
- errorMisplacedMax+=cmdp.connectivity*weight;
- }
- //sum up maxError as the individual errors are the quality weighted between 0-1
- totalErrorMax+=cmdp.connectivity*weight;
-
-
- double err = 0;
- int coverage = 0;
-
- //check every FCluster
- for (int c = 0; c < numFClusters; c++) {
- //contained in cluster c?
- if(pointInclusionProbFC[p][c] >= pointInclusionProbThreshold){
- coverage++;
-
- if(!cmdp.isNoise()){
- //PLACED CORRECTLY
- if(matchMap[c] == cmdp.workclass()){
- }
- //MISPLACED
- else{
- double errvalue = misplacedError(cmdp, c);
- if(errvalue > err)
- err = errvalue;
- }
- }
- else{
- //NOISE
- double errvalue = noiseError(cmdp, c);
- if(errvalue > err) err = errvalue;
- }
- }
- }
- //not in any cluster
- if(coverage == 0){
- //MISSED
- if(!cmdp.isNoise()){
- err = missedError(cmdp,true);
- errorMissed+= weight*err;
- }
- //NOISE
- else{
- }
- }
- else{
- if(!cmdp.isNoise()){
- errorMisplaced+= err*weight;
- }
- else{
- errorNoise+= err*weight;
- }
- }
-
- /* processing of other evaluation values */
- totalError+= err*weight;
- if(err!=0)totalErrorCount++;
- if(coverage>0) totalCoverage++; //points covered by clustering (incl. noise)
- if(coverage>0 && !cmdp.isNoise()) trueCoverage++; //points covered by clustering, don't count noise
- if(coverage>1) totalRedundancy++; //include noise
-
- cmdp.p.setMeasureValue("CMM",err);
- cmdp.p.setMeasureValue("Redundancy", coverage);
+ }
+ else {
+ if (!cmdp.isNoise()) {
+ errorMisplaced += err * weight;
}
-
- addValue("CMM", (totalErrorMax!=0)?1-totalError/totalErrorMax:1);
- addValue("CMM Missed", (errorMissedMax!=0)?1-errorMissed/errorMissedMax:1);
- addValue("CMM Misplaced", (errorMisplacedMax!=0)?1-errorMisplaced/errorMisplacedMax:1);
- addValue("CMM Noise", (errorNoiseMax!=0)?1-errorNoise/errorNoiseMax:1);
- addValue("CMM Basic", 1-((double)totalErrorCount/(double)numPoints));
-
- if(debug){
- System.out.println("-------------");
+ else {
+ errorNoise += err * weight;
}
+ }
+
+ /* processing of other evaluation values */
+ totalError += err * weight;
+ if (err != 0)
+ totalErrorCount++;
+ if (coverage > 0)
+ totalCoverage++; // points covered by clustering (incl. noise)
+ if (coverage > 0 && !cmdp.isNoise())
+ trueCoverage++; // points covered by clustering, don't count noise
+ if (coverage > 1)
+ totalRedundancy++; // include noise
+
+ cmdp.p.setMeasureValue("CMM", err);
+ cmdp.p.setMeasureValue("Redundancy", coverage);
}
+ addValue("CMM", (totalErrorMax != 0) ? 1 - totalError / totalErrorMax : 1);
+ addValue("CMM Missed", (errorMissedMax != 0) ? 1 - errorMissed / errorMissedMax : 1);
+ addValue("CMM Misplaced", (errorMisplacedMax != 0) ? 1 - errorMisplaced / errorMisplacedMax : 1);
+ addValue("CMM Noise", (errorNoiseMax != 0) ? 1 - errorNoise / errorNoiseMax : 1);
+ addValue("CMM Basic", 1 - ((double) totalErrorCount / (double) numPoints));
- private double noiseError(CMMPoint cmdp, int assignedClusterID){
- int gtAssignedID = matchMap[assignedClusterID];
- double error;
-
- //Cluster wasn't matched, so just contains noise
- //TODO: Noiscluster?
- //also happens when we decrease the radius and there is only a noise point in the center
- if(gtAssignedID==-1){
- error = 1;
- cmdp.p.setMeasureValue("CMM Type","noise - cluster");
- }
- else{
- if(enableModelError && gtAnalysis.getGT0Cluster(gtAssignedID).getInclusionProbability(cmdp) >= pointInclusionProbThreshold){
- //set to MIN_ERROR so we can still track the error
- error = 0.00001;
- cmdp.p.setMeasureValue("CMM Type","noise - byModel");
- }
- else{
- error = 1 - gtAnalysis.getConnectionValue(cmdp, gtAssignedID);
- cmdp.p.setMeasureValue("CMM Type","noise");
- }
- }
+ if (debug) {
+ System.out.println("-------------");
+ }
+ }
- return error;
+ private double noiseError(CMMPoint cmdp, int assignedClusterID) {
+ int gtAssignedID = matchMap[assignedClusterID];
+ double error;
+
+ // Cluster wasn't matched, so just contains noise
+ // TODO: noise cluster?
+ // also happens when we decrease the radius and there is only a noise point
+ // in the center
+ if (gtAssignedID == -1) {
+ error = 1;
+ cmdp.p.setMeasureValue("CMM Type", "noise - cluster");
+ }
+ else {
+ if (enableModelError
+ && gtAnalysis.getGT0Cluster(gtAssignedID).getInclusionProbability(cmdp) >= pointInclusionProbThreshold) {
+ // set to MIN_ERROR so we can still track the error
+ error = 0.00001;
+ cmdp.p.setMeasureValue("CMM Type", "noise - byModel");
+ }
+ else {
+ error = 1 - gtAnalysis.getConnectionValue(cmdp, gtAssignedID);
+ cmdp.p.setMeasureValue("CMM Type", "noise");
+ }
}
- private double missedError(CMMPoint cmdp, boolean useHullDistance){
- cmdp.p.setMeasureValue("CMM Type","missed");
- if(!useHullDistance){
- return cmdp.connectivity;
+ return error;
+ }
+
+ private double missedError(CMMPoint cmdp, boolean useHullDistance) {
+ cmdp.p.setMeasureValue("CMM Type", "missed");
+ if (!useHullDistance) {
+ return cmdp.connectivity;
+ }
+ else {
+ // main idea: look at relative distance of missed point to cluster
+ double minHullDist = 1;
+ for (int fc = 0; fc < numFClusters; fc++) {
+ // if fc is mapped onto the class of the point, check it for its
+ // hulldist
+ if (matchMap[fc] != -1 && matchMap[fc] == cmdp.workclass()) {
+ if (clustering.get(fc) instanceof SphereCluster) {
+ SphereCluster sc = (SphereCluster) clustering.get(fc);
+ double distanceFC = sc.getCenterDistance(cmdp);
+ double radius = sc.getRadius();
+ double hullDist = (distanceFC - radius) / (distanceFC + radius);
+ if (hullDist < minHullDist)
+ minHullDist = hullDist;
+ }
+ else {
+ double min = 1;
+ double max = 1;
+
+ // TODO: distance for random shape
+ // generate X points from the cluster with
+ // clustering.get(fc).sample(null)
+ // and find Min and Max values
+
+ double hullDist = min / max;
+ if (hullDist < minHullDist)
+ minHullDist = hullDist;
+ }
}
- else{
- //main idea: look at relative distance of missed point to cluster
- double minHullDist = 1;
- for (int fc = 0; fc < numFClusters; fc++){
- //if fc is mappend onto the class of the point, check it for its hulldist
- if(matchMap[fc]!=-1 && matchMap[fc] == cmdp.workclass()){
- if(clustering.get(fc) instanceof SphereCluster){
- SphereCluster sc = (SphereCluster)clustering.get(fc);
- double distanceFC = sc.getCenterDistance(cmdp);
- double radius = sc.getRadius();
- double hullDist = (distanceFC-radius)/(distanceFC+radius);
- if(hullDist < minHullDist)
- minHullDist = hullDist;
- }
- else{
- double min = 1;
- double max = 1;
+ }
- //TODO: distance for random shape
- //generate X points from the cluster with clustering.get(fc).sample(null)
- //and find Min and Max values
+ // use distance as weight
+ if (minHullDist > 1)
+ minHullDist = 1;
- double hullDist = min/max;
- if(hullDist < minHullDist)
- minHullDist = hullDist;
- }
- }
- }
+ double weight = (1 - Math.exp(-lamdaMissed * minHullDist));
+ cmdp.p.setMeasureValue("HullDistWeight", weight);
- //use distance as weight
- if(minHullDist>1) minHullDist = 1;
+ return weight * cmdp.connectivity;
+ }
+ }
- double weight = (1-Math.exp(-lamdaMissed*minHullDist));
- cmdp.p.setMeasureValue("HullDistWeight",weight);
+ private double misplacedError(CMMPoint cmdp, int assignedClusterID) {
+ double weight = 0;
- return weight*cmdp.connectivity;
- }
+ int gtAssignedID = matchMap[assignedClusterID];
+ // TODO take care of noise cluster?
+ if (gtAssignedID == -1) {
+ System.out.println("Point " + cmdp.getTimestamp() + " from gtcluster " + cmdp.trueClass
+ + " assigned to noise cluster " + assignedClusterID);
+ return 1;
}
-
- private double misplacedError(CMMPoint cmdp, int assignedClusterID){
- double weight = 0;
-
- int gtAssignedID = matchMap[assignedClusterID];
- //TODO take care of noise cluster?
- if(gtAssignedID ==-1){
- System.out.println("Point "+cmdp.getTimestamp()+" from gtcluster "+cmdp.trueClass+" assigned to noise cluster "+assignedClusterID);
- return 1;
- }
-
- if(gtAssignedID == cmdp.workclass())
- return 0;
- else{
- //assigned and real GT0 cluster are not connected, but does the model have the
- //chance of separating this point after all?
- if(enableModelError && gtAnalysis.getGT0Cluster(gtAssignedID).getInclusionProbability(cmdp) >= pointInclusionProbThreshold){
- weight = 0;
- cmdp.p.setMeasureValue("CMM Type","missplaced - byModel");
- }
- else{
- //point was mapped onto wrong cluster (assigned), so check how far away
- //the nearest point is within the wrongly assigned cluster
- weight = 1 - gtAnalysis.getConnectionValue(cmdp, gtAssignedID);
- }
- }
- double err_value;
- //set to MIN_ERROR so we can still track the error
- if(weight == 0){
- err_value= 0.00001;
- }
- else{
- err_value = weight*cmdp.connectivity;
- cmdp.p.setMeasureValue("CMM Type","missplaced");
- }
-
- return err_value;
+ if (gtAssignedID == cmdp.workclass())
+ return 0;
+ else {
+ // assigned and real GT0 cluster are not connected, but does the model
+ // have the
+ // chance of separating this point after all?
+ if (enableModelError
+ && gtAnalysis.getGT0Cluster(gtAssignedID).getInclusionProbability(cmdp) >= pointInclusionProbThreshold) {
+ weight = 0;
+ cmdp.p.setMeasureValue("CMM Type", "missplaced - byModel");
+ }
+ else {
+ // point was mapped onto wrong cluster (assigned), so check how far away
+ // the nearest point is within the wrongly assigned cluster
+ weight = 1 - gtAnalysis.getConnectionValue(cmdp, gtAssignedID);
+ }
+ }
+ double err_value;
+ // set to MIN_ERROR so we can still track the error
+ if (weight == 0) {
+ err_value = 0.00001;
+ }
+ else {
+ err_value = weight * cmdp.connectivity;
+ cmdp.p.setMeasureValue("CMM Type", "missplaced");
}
- public String getParameterString(){
- String para = gtAnalysis.getParameterString();
- para+="lambdaMissed="+lamdaMissed+";";
- return para;
- }
+ return err_value;
+ }
+
+ public String getParameterString() {
+ String para = gtAnalysis.getParameterString();
+ para += "lambdaMissed=" + lamdaMissed + ";";
+ return para;
+ }
}
-
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM_GTAnalysis.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM_GTAnalysis.java
index 53fb4dc..c31fa74 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM_GTAnalysis.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/CMM_GTAnalysis.java
@@ -33,811 +33,828 @@
*
* CMM: Ground truth analysis
*
- * Reference: Kremer et al., "An Effective Evaluation Measure for Clustering on Evolving Data Streams", KDD, 2011
+ * Reference: Kremer et al.,
+ * "An Effective Evaluation Measure for Clustering on Evolving Data Streams",
+ * KDD, 2011
*
- * @author Timm jansen
- * Data Management and Data Exploration Group, RWTH Aachen University
-*/
+ * @author Timm jansen Data Management and Data Exploration Group, RWTH Aachen
+ * University
+ */
/*
- * TODO:
- * - try to avoid calcualting the radius multiple times
- * - avoid the full distance map?
- * - knn functionality in clusters
- * - noise error
+ * TODO: - try to avoid calculating the radius multiple times - avoid the full
+ * distance map? - knn functionality in clusters - noise error
*/
-public class CMM_GTAnalysis{
-
+public class CMM_GTAnalysis {
+
+ /**
+ * the given ground truth clustering
+ */
+ private Clustering gtClustering;
+
+ /**
+ * list of given points within the horizon
+ */
+ private ArrayList<CMMPoint> cmmpoints;
+
+ /**
+ * the newly calculate ground truth clustering
+ */
+ private ArrayList<GTCluster> gt0Clusters;
+
+ /**
+ * IDs of noise points
+ */
+ private ArrayList<Integer> noise;
+
+ /**
+ * total number of points
+ */
+ private int numPoints;
+
+ /**
+ * number of clusters of the original ground truth
+ */
+ private int numGTClusters;
+
+ /**
+ * number of classes of the original ground truth, in case of a micro
+ * clustering ground truth this differs from numGTClusters
+ */
+ private int numGTClasses;
+
+ /**
+ * number of classes after we are done with the analysis
+ */
+ private int numGT0Classes;
+
+ /**
+ * number of dimensions
+ */
+ private int numDims;
+
+ /**
+ * mapping between true cluster ID/class label of the original ground truth
+ * and the internal cluster ID/working class label.
+ *
+ * different original cluster IDs might map to the same new cluster ID due to
+ * merging of two clusters
+ */
+ private HashMap<Integer, Integer> mapTrueLabelToWorkLabel;
+
+ /**
+ * log of how clusters have been merged (for debugging)
+ */
+ private int[] mergeMap;
+
+ /**
+ * number of non-noise points that will create an error due to the underlying
+ * clustering model (e.g. point being covered by two clusters representing
+ * different classes)
+ */
+ private int noiseErrorByModel;
+
+ /**
+ * number of noise points that will create an error due to the underlying
+ * clustering model (e.g. noise point being covered by a cluster)
+ */
+ private int pointErrorByModel;
+
+ /**
+ * CMM debug mode
+ */
+ private boolean debug = false;
+
+ /******* CMM parameter ***********/
+
+ /**
+ * defines how many nearest neighbors will be used
+ */
+ private int knnNeighbourhood = 2;
+
+ /**
+ * the threshold which defines when ground truth clusters will be merged. set
+ * to 1 to disable merging
+ */
+ private double tauConnection = 0.5;
+
+ /**
+ * experimental (default: disabled) separate k for points to cluster and
+ * cluster to cluster
+ */
+ private double clusterConnectionMaxPoints = knnNeighbourhood;
+
+ /**
+ * experimental (default: disabled) use exponential connectivity function to
+ * model different behavior: closer points will have a stronger connection
+ * compared to the linear function. Use ConnRefXValue and ConnX to better
+ * parameterize lambda, which controls the decay of the connectivity
+ */
+ private boolean useExpConnectivity = false;
+ private double lambdaConnRefXValue = 0.01;
+ private double lambdaConnX = 4;
+ private double lamdaConn;
+
+ /******************************************/
+
+ /**
+ * Wrapper class for data points to store CMM relevant attributes
+ *
+ */
+ protected class CMMPoint extends DataPoint {
/**
- * the given ground truth clustering
+ * Reference to original point
*/
- private Clustering gtClustering;
-
- /**
- * list of given points within the horizon
- */
- private ArrayList<CMMPoint> cmmpoints;
-
- /**
- * the newly calculate ground truth clustering
- */
- private ArrayList<GTCluster> gt0Clusters;
+ protected DataPoint p = null;
/**
- * IDs of noise points
+ * point ID
*/
- private ArrayList<Integer> noise;
-
- /**
- * total number of points
- */
- private int numPoints;
+ protected int pID = 0;
/**
- * number of clusters of the original ground truth
+ * true class label
*/
- private int numGTClusters;
+ protected int trueClass = -1;
/**
- * number of classes of the original ground truth, in case of a
- * micro clustering ground truth this differs from numGTClusters
+ * the connectivity of the point to its cluster
*/
- private int numGTClasses;
+ protected double connectivity = 1.0;
/**
- * number of classes after we are done with the analysis
+ * knn distance within own cluster
*/
- private int numGT0Classes;
+ protected double knnInCluster = 0.0;
/**
- * number of dimensions
+ * knn indices (for debugging only)
*/
- private int numDims;
+ protected ArrayList<Integer> knnIndices;
+
+ public CMMPoint(DataPoint point, int id) {
+ // make a copy, but keep reference
+ super(point, point.getTimestamp());
+ p = point;
+ pID = id;
+ trueClass = (int) point.classValue();
+ }
/**
- * mapping between true cluster ID/class label of the original ground truth
- * and the internal cluster ID/working class label.
+ * Returns the current working label of the cluster the point belongs to.
+ * The label can change due to merging of clusters.
*
- * different original cluster IDs might map to the same new cluster ID due to merging of two clusters
+ * @return the current working class label
*/
- private HashMap<Integer, Integer> mapTrueLabelToWorkLabel;
+ protected int workclass() {
+ if (trueClass == -1)
+ return -1;
+ else
+ return mapTrueLabelToWorkLabel.get(trueClass);
+ }
+ }
+
+ /**
+ * Main class to model the new clusters that will be the output of the cluster
+ * analysis
+ *
+ */
+ protected class GTCluster {
+ /** points that are per definition in the cluster */
+ private ArrayList<Integer> points = new ArrayList<Integer>();
/**
- * log of how clusters have been merged (for debugging)
+ * a new GT cluster consists of one or more "old" GT clusters.
+ * Connected/overlapping clusters cannot be merged directly because of the
+ * underlying cluster model. E.g. for merging two spherical clusters the new
+ * cluster sphere can cover a lot more space then two separate smaller
+ * spheres. To keep the original coverage we need to keep the original
+ * clusters and merge them on an abstract level.
*/
- private int[] mergeMap;
+ private ArrayList<Integer> clusterRepresentations = new ArrayList<Integer>();
- /**
- * number of non-noise points that will create an error due to the underlying clustering model
- * (e.g. point being covered by two clusters representing different classes)
- */
- private int noiseErrorByModel;
+ /** current work class (changes when merging) */
+ private int workclass;
- /**
- * number of noise points that will create an error due to the underlying clustering model
- * (e.g. noise point being covered by a cluster)
- */
- private int pointErrorByModel;
-
- /**
- * CMM debug mode
- */
- private boolean debug = false;
+ /** original work class */
+ private final int orgWorkClass;
-
- /******* CMM parameter ***********/
+ /** original class label */
+ private final int label;
- /**
- * defines how many nearest neighbors will be used
- */
- private int knnNeighbourhood = 2;
+ /** clusters that have been merged into this cluster (debugging) */
+ private ArrayList<Integer> mergedWorkLabels = null;
- /**
- * the threshold which defines when ground truth clusters will be merged.
- * set to 1 to disable merging
- */
- private double tauConnection = 0.5;
-
- /**
- * experimental (default: disabled)
- * separate k for points to cluster and cluster to cluster
- */
- private double clusterConnectionMaxPoints = knnNeighbourhood;
-
- /**
- * experimental (default: disabled)
- * use exponential connectivity function to model different behavior:
- * closer points will have a stronger connection compared to the linear function.
- * Use ConnRefXValue and ConnX to better parameterize lambda, which controls
- * the decay of the connectivity
- */
- private boolean useExpConnectivity = false;
- private double lambdaConnRefXValue = 0.01;
- private double lambdaConnX = 4;
- private double lamdaConn;
-
-
- /******************************************/
-
-
- /**
- * Wrapper class for data points to store CMM relevant attributes
- *
- */
- protected class CMMPoint extends DataPoint{
- /**
- * Reference to original point
- */
- protected DataPoint p = null;
-
- /**
- * point ID
- */
- protected int pID = 0;
-
-
- /**
- * true class label
- */
- protected int trueClass = -1;
+ /** average knn distance of all points in the cluster */
+ private double knnMeanAvg = 0;
-
- /**
- * the connectivity of the point to its cluster
- */
- protected double connectivity = 1.0;
-
-
- /**
- * knn distnace within own cluster
- */
- protected double knnInCluster = 0.0;
-
-
- /**
- * knn indices (for debugging only)
- */
- protected ArrayList<Integer> knnIndices;
+ /** average deviation of knn distance of all points */
+ private double knnDevAvg = 0;
- public CMMPoint(DataPoint point, int id) {
- //make a copy, but keep reference
- super(point,point.getTimestamp());
- p = point;
- pID = id;
- trueClass = (int)point.classValue();
- }
+ /** connectivity of the cluster to all other clusters */
+ private ArrayList<Double> connections = new ArrayList<Double>();
-
- /**
- * Retruns the current working label of the cluster the point belongs to.
- * The label can change due to merging of clusters.
- *
- * @return the current working class label
- */
- protected int workclass(){
- if(trueClass == -1 )
- return -1;
- else
- return mapTrueLabelToWorkLabel.get(trueClass);
- }
+ private GTCluster(int workclass, int label, int gtClusteringID) {
+ this.orgWorkClass = workclass;
+ this.workclass = workclass;
+ this.label = label;
+ this.clusterRepresentations.add(gtClusteringID);
}
-
-
/**
- * Main class to model the new clusters that will be the output of the cluster analysis
- *
+ * The original class label the cluster represents
+ *
+ * @return original class label
*/
- protected class GTCluster{
- /** points that are per definition in the cluster */
- private ArrayList<Integer> points = new ArrayList<Integer>();
-
- /** a new GT cluster consists of one or more "old" GT clusters.
- * Connected/overlapping clusters cannot be merged directly because of the
- * underlying cluster model. E.g. for merging two spherical clusters the new
- * cluster sphere can cover a lot more space then two separate smaller spheres.
- * To keep the original coverage we need to keep the orignal clusters and merge
- * them on an abstract level. */
- private ArrayList<Integer> clusterRepresentations = new ArrayList<Integer>();
-
- /** current work class (changes when merging) */
- private int workclass;
-
- /** original work class */
- private final int orgWorkClass;
-
- /** original class label*/
- private final int label;
-
- /** clusters that have been merged into this cluster (debugging)*/
- private ArrayList<Integer> mergedWorkLabels = null;
-
- /** average knn distance of all points in the cluster*/
- private double knnMeanAvg = 0;
-
- /** average deviation of knn distance of all points*/
- private double knnDevAvg = 0;
-
- /** connectivity of the cluster to all other clusters */
- private ArrayList<Double> connections = new ArrayList<Double>();
-
-
- private GTCluster(int workclass, int label, int gtClusteringID) {
- this.orgWorkClass = workclass;
- this.workclass = workclass;
- this.label = label;
- this.clusterRepresentations.add(gtClusteringID);
- }
-
-
- /**
- * The original class label the cluster represents
- * @return original class label
- */
- protected int getLabel(){
- return label;
- }
-
- /**
- * Calculate the probability of the point being covered through the cluster
- * @param point to calculate the probability for
- * @return probability of the point being covered through the cluster
- */
- protected double getInclusionProbability(CMMPoint point){
- double prob = Double.MIN_VALUE;
- //check all cluster representatives for coverage
- for (int c = 0; c < clusterRepresentations.size(); c++) {
- double tmp_prob = gtClustering.get(clusterRepresentations.get(c)).getInclusionProbability(point);
- if(tmp_prob > prob) prob = tmp_prob;
- }
- return prob;
- }
-
-
- /**
- * calculate knn distances of points within own cluster
- * + average knn distance and average knn distance deviation of all points
- */
- private void calculateKnn(){
- for (int p0 : points) {
- CMMPoint cmdp = cmmpoints.get(p0);
- if(!cmdp.isNoise()){
- AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>();
- AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>();
-
- //calculate nearest neighbours
- getKnnInCluster(cmdp, knnNeighbourhood, points, knnDist,knnPointIndex);
-
- //TODO: What to do if we have less then k neighbours?
- double avgKnn = 0;
- for (int i = 0; i < knnDist.size(); i++) {
- avgKnn+= knnDist.get(i);
- }
- if(knnDist.size()!=0)
- avgKnn/=knnDist.size();
- cmdp.knnInCluster = avgKnn;
- cmdp.knnIndices = knnPointIndex;
- cmdp.p.setMeasureValue("knnAvg", cmdp.knnInCluster);
-
- knnMeanAvg+=avgKnn;
- knnDevAvg+=Math.pow(avgKnn,2);
- }
- }
- knnMeanAvg=knnMeanAvg/(double)points.size();
- knnDevAvg=knnDevAvg/(double)points.size();
-
- double variance = knnDevAvg-Math.pow(knnMeanAvg,2.0);
- // Due to numerical errors, small negative values can occur.
- if (variance <= 0.0) variance = 1e-50;
- knnDevAvg = Math.sqrt(variance);
-
- }
-
-
- /**
- * Calculate the connection of a cluster to this cluster
- * @param otherCid cluster id of the other cluster
- * @param initial flag for initial run
- */
- private void calculateClusterConnection(int otherCid, boolean initial){
- double avgConnection = 0;
- if(workclass==otherCid){
- avgConnection = 1;
- }
- else{
- AutoExpandVector<Double> kmax = new AutoExpandVector<Double>();
- AutoExpandVector<Integer> kmaxIndexes = new AutoExpandVector<Integer>();
-
- for(int p : points){
- CMMPoint cmdp = cmmpoints.get(p);
- double con_p_Cj = getConnectionValue(cmmpoints.get(p), otherCid);
- double connection = cmdp.connectivity * con_p_Cj;
- if(initial){
- cmdp.p.setMeasureValue("Connection to C"+otherCid, con_p_Cj);
- }
-
- //connection
- if(kmax.size() < clusterConnectionMaxPoints || connection > kmax.get(kmax.size()-1)){
- int index = 0;
- while(index < kmax.size() && connection < kmax.get(index)) {
- index++;
- }
- kmax.add(index, connection);
- kmaxIndexes.add(index, p);
- if(kmax.size() > clusterConnectionMaxPoints){
- kmax.remove(kmax.size()-1);
- kmaxIndexes.add(kmaxIndexes.size()-1);
- }
- }
- }
- //connection
- for (int k = 0; k < kmax.size(); k++) {
- avgConnection+= kmax.get(k);
- }
- avgConnection/=kmax.size();
- }
-
- if(otherCid<connections.size()){
- connections.set(otherCid, avgConnection);
- }
- else
- if(connections.size() == otherCid){
- connections.add(avgConnection);
- }
- else
- System.out.println("Something is going really wrong with the connection listing!"+knnNeighbourhood+" "+tauConnection);
- }
-
-
- /**
- * Merge a cluster into this cluster
- * @param mergeID the ID of the cluster to be merged
- */
- private void mergeCluster(int mergeID){
- if(mergeID < gt0Clusters.size()){
- //track merging (debugging)
- for (int i = 0; i < numGTClasses; i++) {
- if(mergeMap[i]==mergeID)
- mergeMap[i]=workclass;
- if(mergeMap[i]>mergeID)
- mergeMap[i]--;
- }
- GTCluster gtcMerge = gt0Clusters.get(mergeID);
- if(debug)
- System.out.println("Merging C"+gtcMerge.workclass+" into C"+workclass+
- " with Con "+connections.get(mergeID)+" / "+gtcMerge.connections.get(workclass));
-
-
- //update mapTrueLabelToWorkLabel
- mapTrueLabelToWorkLabel.put(gtcMerge.label, workclass);
- Iterator iterator = mapTrueLabelToWorkLabel.keySet().iterator();
- while (iterator.hasNext()) {
- Integer key = (Integer)iterator.next();
- //update pointer of already merged cluster
- int value = mapTrueLabelToWorkLabel.get(key);
- if(value == mergeID)
- mapTrueLabelToWorkLabel.put(key, workclass);
- if(value > mergeID)
- mapTrueLabelToWorkLabel.put(key, value-1);
- }
-
- //merge points from B into A
- points.addAll(gtcMerge.points);
- clusterRepresentations.addAll(gtcMerge.clusterRepresentations);
- if(mergedWorkLabels==null){
- mergedWorkLabels = new ArrayList<Integer>();
- }
- mergedWorkLabels.add(gtcMerge.orgWorkClass);
- if(gtcMerge.mergedWorkLabels!=null)
- mergedWorkLabels.addAll(gtcMerge.mergedWorkLabels);
-
- gt0Clusters.remove(mergeID);
-
- //update workclass labels
- for(int c=mergeID; c < gt0Clusters.size(); c++){
- gt0Clusters.get(c).workclass = c;
- }
-
- //update knn distances
- calculateKnn();
- for(int c=0; c < gt0Clusters.size(); c++){
- gt0Clusters.get(c).connections.remove(mergeID);
-
- //recalculate connection from other clusters to the new merged one
- gt0Clusters.get(c).calculateClusterConnection(workclass,false);
- //and from new merged one to other clusters
- gt0Clusters.get(workclass).calculateClusterConnection(c,false);
- }
- }
- else{
- System.out.println("Merge indices are not valid");
- }
- }
+ protected int getLabel() {
+ return label;
}
-
/**
- * @param trueClustering the ground truth clustering
- * @param points data points
- * @param enableClassMerge allow class merging (should be set to true on default)
+ * Calculate the probability of the point being covered through the cluster
+ *
+ * @param point
+ * to calculate the probability for
+ * @return probability of the point being covered through the cluster
*/
- public CMM_GTAnalysis(Clustering trueClustering, ArrayList<DataPoint> points, boolean enableClassMerge){
- if(debug)
- System.out.println("GT Analysis Debug Output");
+ protected double getInclusionProbability(CMMPoint point) {
+ double prob = Double.MIN_VALUE;
+ // check all cluster representatives for coverage
+ for (int c = 0; c < clusterRepresentations.size(); c++) {
+ double tmp_prob = gtClustering.get(clusterRepresentations.get(c)).getInclusionProbability(point);
+ if (tmp_prob > prob)
+ prob = tmp_prob;
+ }
+ return prob;
+ }
- noiseErrorByModel = 0;
- pointErrorByModel = 0;
- if(!enableClassMerge){
- tauConnection = 1.0;
+ /**
+ * calculate knn distances of points within own cluster + average knn
+ * distance and average knn distance deviation of all points
+ */
+ private void calculateKnn() {
+ for (int p0 : points) {
+ CMMPoint cmdp = cmmpoints.get(p0);
+ if (!cmdp.isNoise()) {
+ AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>();
+ AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>();
+
+ // calculate nearest neighbours
+ getKnnInCluster(cmdp, knnNeighbourhood, points, knnDist, knnPointIndex);
+
+ // TODO: What to do if we have fewer than k neighbours?
+ double avgKnn = 0;
+ for (int i = 0; i < knnDist.size(); i++) {
+ avgKnn += knnDist.get(i);
+ }
+ if (knnDist.size() != 0)
+ avgKnn /= knnDist.size();
+ cmdp.knnInCluster = avgKnn;
+ cmdp.knnIndices = knnPointIndex;
+ cmdp.p.setMeasureValue("knnAvg", cmdp.knnInCluster);
+
+ knnMeanAvg += avgKnn;
+ knnDevAvg += Math.pow(avgKnn, 2);
}
+ }
+ knnMeanAvg = knnMeanAvg / (double) points.size();
+ knnDevAvg = knnDevAvg / (double) points.size();
- lamdaConn = -Math.log(lambdaConnRefXValue)/Math.log(2)/lambdaConnX;
-
- this.gtClustering = trueClustering;
+ double variance = knnDevAvg - Math.pow(knnMeanAvg, 2.0);
+ // Due to numerical errors, small negative values can occur.
+ if (variance <= 0.0)
+ variance = 1e-50;
+ knnDevAvg = Math.sqrt(variance);
- numPoints = points.size();
- numDims = points.get(0).numAttributes()-1;
- numGTClusters = gtClustering.size();
+ }
- //init mappings between work and true labels
- mapTrueLabelToWorkLabel = new HashMap<Integer, Integer>();
-
- //set up base of new clustering
- gt0Clusters = new ArrayList<GTCluster>();
- int numWorkClasses = 0;
- //create label to worklabel mapping as real labels can be just a set of unordered integers
- for (int i = 0; i < numGTClusters; i++) {
- int label = (int)gtClustering.get(i).getGroundTruth();
- if(!mapTrueLabelToWorkLabel.containsKey(label)){
- gt0Clusters.add(new GTCluster(numWorkClasses,label,i));
- mapTrueLabelToWorkLabel.put(label,numWorkClasses);
- numWorkClasses++;
+ /**
+ * Calculate the connection of a cluster to this cluster
+ *
+ * @param otherCid
+ * cluster id of the other cluster
+ * @param initial
+ * flag for initial run
+ */
+ private void calculateClusterConnection(int otherCid, boolean initial) {
+ double avgConnection = 0;
+ if (workclass == otherCid) {
+ avgConnection = 1;
+ }
+ else {
+ AutoExpandVector<Double> kmax = new AutoExpandVector<Double>();
+ AutoExpandVector<Integer> kmaxIndexes = new AutoExpandVector<Integer>();
+
+ for (int p : points) {
+ CMMPoint cmdp = cmmpoints.get(p);
+ double con_p_Cj = getConnectionValue(cmmpoints.get(p), otherCid);
+ double connection = cmdp.connectivity * con_p_Cj;
+ if (initial) {
+ cmdp.p.setMeasureValue("Connection to C" + otherCid, con_p_Cj);
+ }
+
+ // connection
+ if (kmax.size() < clusterConnectionMaxPoints || connection > kmax.get(kmax.size() - 1)) {
+ int index = 0;
+ while (index < kmax.size() && connection < kmax.get(index)) {
+ index++;
}
- else{
- gt0Clusters.get(mapTrueLabelToWorkLabel.get(label)).clusterRepresentations.add(i);
+ kmax.add(index, connection);
+ kmaxIndexes.add(index, p);
+ if (kmax.size() > clusterConnectionMaxPoints) {
+ kmax.remove(kmax.size() - 1);
+ kmaxIndexes.add(kmaxIndexes.size() - 1);
}
+ }
}
- numGTClasses = numWorkClasses;
+ // connection
+ for (int k = 0; k < kmax.size(); k++) {
+ avgConnection += kmax.get(k);
+ }
+ avgConnection /= kmax.size();
+ }
- mergeMap = new int[numGTClasses];
+ if (otherCid < connections.size()) {
+ connections.set(otherCid, avgConnection);
+ }
+ else if (connections.size() == otherCid) {
+ connections.add(avgConnection);
+ }
+ else
+ System.out.println("Something is going really wrong with the connection listing!" + knnNeighbourhood + " "
+ + tauConnection);
+ }
+
+ /**
+ * Merge a cluster into this cluster
+ *
+ * @param mergeID
+ * the ID of the cluster to be merged
+ */
+ private void mergeCluster(int mergeID) {
+ if (mergeID < gt0Clusters.size()) {
+ // track merging (debugging)
for (int i = 0; i < numGTClasses; i++) {
- mergeMap[i]=i;
+ if (mergeMap[i] == mergeID)
+ mergeMap[i] = workclass;
+ if (mergeMap[i] > mergeID)
+ mergeMap[i]--;
+ }
+ GTCluster gtcMerge = gt0Clusters.get(mergeID);
+ if (debug)
+ System.out.println("Merging C" + gtcMerge.workclass + " into C" + workclass +
+ " with Con " + connections.get(mergeID) + " / " + gtcMerge.connections.get(workclass));
+
+ // update mapTrueLabelToWorkLabel
+ mapTrueLabelToWorkLabel.put(gtcMerge.label, workclass);
+ Iterator iterator = mapTrueLabelToWorkLabel.keySet().iterator();
+ while (iterator.hasNext()) {
+ Integer key = (Integer) iterator.next();
+ // update pointer of already merged cluster
+ int value = mapTrueLabelToWorkLabel.get(key);
+ if (value == mergeID)
+ mapTrueLabelToWorkLabel.put(key, workclass);
+ if (value > mergeID)
+ mapTrueLabelToWorkLabel.put(key, value - 1);
}
- //create cmd point wrapper instances
- cmmpoints = new ArrayList<CMMPoint>();
- for (int p = 0; p < points.size(); p++) {
- CMMPoint cmdp = new CMMPoint(points.get(p), p);
- cmmpoints.add(cmdp);
+ // merge points from B into A
+ points.addAll(gtcMerge.points);
+ clusterRepresentations.addAll(gtcMerge.clusterRepresentations);
+ if (mergedWorkLabels == null) {
+ mergedWorkLabels = new ArrayList<Integer>();
+ }
+ mergedWorkLabels.add(gtcMerge.orgWorkClass);
+ if (gtcMerge.mergedWorkLabels != null)
+ mergedWorkLabels.addAll(gtcMerge.mergedWorkLabels);
+
+ gt0Clusters.remove(mergeID);
+
+ // update workclass labels
+ for (int c = mergeID; c < gt0Clusters.size(); c++) {
+ gt0Clusters.get(c).workclass = c;
}
+ // update knn distances
+ calculateKnn();
+ for (int c = 0; c < gt0Clusters.size(); c++) {
+ gt0Clusters.get(c).connections.remove(mergeID);
- //split points up into their GTClusters and Noise (according to class labels)
- noise = new ArrayList<Integer>();
- for (int p = 0; p < numPoints; p++) {
- if(cmmpoints.get(p).isNoise()){
- noise.add(p);
- }
- else{
- gt0Clusters.get(cmmpoints.get(p).workclass()).points.add(p);
- }
+ // recalculate connection from other clusters to the new merged one
+ gt0Clusters.get(c).calculateClusterConnection(workclass, false);
+ // and from new merged one to other clusters
+ gt0Clusters.get(workclass).calculateClusterConnection(c, false);
}
+ }
+ else {
+ System.out.println("Merge indices are not valid");
+ }
+ }
+ }
- //calculate initial knnMean and knnDev
- for (GTCluster gtc : gt0Clusters) {
- gtc.calculateKnn();
- }
+ /**
+ * @param trueClustering
+ * the ground truth clustering
+ * @param points
+ * data points
+ * @param enableClassMerge
+ * allow class merging (should be set to true on default)
+ */
+ public CMM_GTAnalysis(Clustering trueClustering, ArrayList<DataPoint> points, boolean enableClassMerge) {
+ if (debug)
+ System.out.println("GT Analysis Debug Output");
- //calculate cluster connections
- calculateGTClusterConnections();
+ noiseErrorByModel = 0;
+ pointErrorByModel = 0;
+ if (!enableClassMerge) {
+ tauConnection = 1.0;
+ }
- //calculate point connections with own clusters
- calculateGTPointQualities();
+ lamdaConn = -Math.log(lambdaConnRefXValue) / Math.log(2) / lambdaConnX;
- if(debug)
- System.out.println("GT Analysis Debug End");
+ this.gtClustering = trueClustering;
- }
+ numPoints = points.size();
+ numDims = points.get(0).numAttributes() - 1;
+ numGTClusters = gtClustering.size();
- /**
- * Calculate the connection of a point to a cluster
- *
- * @param cmmp the point to calculate the connection for
- * @param clusterID the corresponding cluster
- * @return the connection value
+ // init mappings between work and true labels
+ mapTrueLabelToWorkLabel = new HashMap<Integer, Integer>();
+
+ // set up base of new clustering
+ gt0Clusters = new ArrayList<GTCluster>();
+ int numWorkClasses = 0;
+ // create label to worklabel mapping as real labels can be just a set of
+ // unordered integers
+ for (int i = 0; i < numGTClusters; i++) {
+ int label = (int) gtClustering.get(i).getGroundTruth();
+ if (!mapTrueLabelToWorkLabel.containsKey(label)) {
+ gt0Clusters.add(new GTCluster(numWorkClasses, label, i));
+ mapTrueLabelToWorkLabel.put(label, numWorkClasses);
+ numWorkClasses++;
+ }
+ else {
+ gt0Clusters.get(mapTrueLabelToWorkLabel.get(label)).clusterRepresentations.add(i);
+ }
+ }
+ numGTClasses = numWorkClasses;
+
+ mergeMap = new int[numGTClasses];
+ for (int i = 0; i < numGTClasses; i++) {
+ mergeMap[i] = i;
+ }
+
+ // create cmd point wrapper instances
+ cmmpoints = new ArrayList<CMMPoint>();
+ for (int p = 0; p < points.size(); p++) {
+ CMMPoint cmdp = new CMMPoint(points.get(p), p);
+ cmmpoints.add(cmdp);
+ }
+
+ // split points up into their GTClusters and Noise (according to class
+ // labels)
+ noise = new ArrayList<Integer>();
+ for (int p = 0; p < numPoints; p++) {
+ if (cmmpoints.get(p).isNoise()) {
+ noise.add(p);
+ }
+ else {
+ gt0Clusters.get(cmmpoints.get(p).workclass()).points.add(p);
+ }
+ }
+
+ // calculate initial knnMean and knnDev
+ for (GTCluster gtc : gt0Clusters) {
+ gtc.calculateKnn();
+ }
+
+ // calculate cluster connections
+ calculateGTClusterConnections();
+
+ // calculate point connections with own clusters
+ calculateGTPointQualities();
+
+ if (debug)
+ System.out.println("GT Analysis Debug End");
+
+ }
+
+ /**
+ * Calculate the connection of a point to a cluster
+ *
+ * @param cmmp
+ * the point to calculate the connection for
+ * @param clusterID
+ * the corresponding cluster
+ * @return the connection value
+ */
+ // TODO: Cache the connection value for a point to the different clusters???
+ protected double getConnectionValue(CMMPoint cmmp, int clusterID) {
+ AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>();
+ AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>();
+
+ // calculate the knn distance of the point to the cluster
+ getKnnInCluster(cmmp, knnNeighbourhood, gt0Clusters.get(clusterID).points, knnDist, knnPointIndex);
+
+ // TODO: What to do if we have fewer than k neighbors?
+ double avgDist = 0;
+ for (int i = 0; i < knnDist.size(); i++) {
+ avgDist += knnDist.get(i);
+ }
+ // what to do if we only have a single point???
+ if (knnDist.size() != 0)
+ avgDist /= knnDist.size();
+ else
+ return 0;
+
+ // get the upper knn distance of the cluster
+ double upperKnn = gt0Clusters.get(clusterID).knnMeanAvg + gt0Clusters.get(clusterID).knnDevAvg;
+
+ /*
+ * calculate the connectivity based on knn distance of the point within the
+ * cluster and the upper knn distance of the cluster
*/
- //TODO: Cache the connection value for a point to the different clusters???
- protected double getConnectionValue(CMMPoint cmmp, int clusterID){
- AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>();
- AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>();
-
- //calculate the knn distance of the point to the cluster
- getKnnInCluster(cmmp, knnNeighbourhood, gt0Clusters.get(clusterID).points, knnDist, knnPointIndex);
+ if (avgDist < upperKnn) {
+ return 1;
+ }
+ else {
+ // value that should be reached at upperKnn distance
+ // Choose connection formula
+ double conn;
+ if (useExpConnectivity)
+ conn = Math.pow(2, -lamdaConn * (avgDist - upperKnn) / upperKnn);
+ else
+ conn = upperKnn / avgDist;
- //TODO: What to do if we have less then k neighbors?
- double avgDist = 0;
- for (int i = 0; i < knnDist.size(); i++) {
- avgDist+= knnDist.get(i);
- }
- //what to do if we only have a single point???
- if(knnDist.size()!=0)
- avgDist/=knnDist.size();
- else
- return 0;
+ if (Double.isNaN(conn))
+ System.out.println("Connectivity NaN at " + cmmp.p.getTimestamp());
- //get the upper knn distance of the cluster
- double upperKnn = gt0Clusters.get(clusterID).knnMeanAvg + gt0Clusters.get(clusterID).knnDevAvg;
-
- /* calculate the connectivity based on knn distance of the point within the cluster
- and the upper knn distance of the cluster*/
- if(avgDist < upperKnn){
- return 1;
+ return conn;
+ }
+ }
+
+ /**
+ * @param cmmp
+ * point to calculate knn distance for
+ * @param k
+ * number of nearest neighbors to look for
+ * @param pointIDs
+ * list of point IDs to check
+ * @param knnDist
+ * sorted list of smallest knn distances (can already be filled to
+ * make updates possible)
+ * @param knnPointIndex
+ * list of corresponding knn indices
+ */
+ private void getKnnInCluster(CMMPoint cmmp, int k,
+ ArrayList<Integer> pointIDs,
+ AutoExpandVector<Double> knnDist,
+ AutoExpandVector<Integer> knnPointIndex) {
+
+ // iterate over every point in the chosen cluster, calc distance and insert
+ // into list
+ for (int p1 = 0; p1 < pointIDs.size(); p1++) {
+ int pid = pointIDs.get(p1);
+ if (cmmp.pID == pid)
+ continue;
+ double dist = distance(cmmp, cmmpoints.get(pid));
+ if (knnDist.size() < k || dist < knnDist.get(knnDist.size() - 1)) {
+ int index = 0;
+ while (index < knnDist.size() && dist > knnDist.get(index)) {
+ index++;
}
- else{
- //value that should be reached at upperKnn distance
- //Choose connection formula
- double conn;
- if(useExpConnectivity)
- conn = Math.pow(2,-lamdaConn*(avgDist-upperKnn)/upperKnn);
+ knnDist.add(index, dist);
+ knnPointIndex.add(index, pid);
+ if (knnDist.size() > k) {
+ knnDist.remove(knnDist.size() - 1);
+ knnPointIndex.remove(knnPointIndex.size() - 1);
+ }
+ }
+ }
+ }
+
+ /**
+ * calculate initial connectivities
+ */
+ private void calculateGTPointQualities() {
+ for (int p = 0; p < numPoints; p++) {
+ CMMPoint cmdp = cmmpoints.get(p);
+ if (!cmdp.isNoise()) {
+ cmdp.connectivity = getConnectionValue(cmdp, cmdp.workclass());
+ cmdp.p.setMeasureValue("Connectivity", cmdp.connectivity);
+ }
+ }
+ }
+
+ /**
+ * Calculate connections between clusters and merge clusters accordingly as
+ * long as connections exceed threshold
+ */
+ private void calculateGTClusterConnections() {
+ for (int c0 = 0; c0 < gt0Clusters.size(); c0++) {
+ for (int c1 = 0; c1 < gt0Clusters.size(); c1++) {
+ gt0Clusters.get(c0).calculateClusterConnection(c1, true);
+ }
+ }
+
+ boolean changedConnection = true;
+ while (changedConnection) {
+ if (debug) {
+ System.out.println("Cluster Connection");
+ for (int c = 0; c < gt0Clusters.size(); c++) {
+ System.out.print("C" + gt0Clusters.get(c).label + " --> ");
+ for (int c1 = 0; c1 < gt0Clusters.get(c).connections.size(); c1++) {
+ System.out.print(" C" + gt0Clusters.get(c1).label + ": " + gt0Clusters.get(c).connections.get(c1));
+ }
+ System.out.println("");
+ }
+ System.out.println("");
+ }
+
+ double max = 0;
+ int maxIndexI = -1;
+ int maxIndexJ = -1;
+
+ changedConnection = false;
+ for (int c0 = 0; c0 < gt0Clusters.size(); c0++) {
+ for (int c1 = c0 + 1; c1 < gt0Clusters.size(); c1++) {
+ if (c0 == c1)
+ continue;
+ double min = Math.min(gt0Clusters.get(c0).connections.get(c1), gt0Clusters.get(c1).connections.get(c0));
+ if (min > max) {
+ max = min;
+ maxIndexI = c0;
+ maxIndexJ = c1;
+ }
+ }
+ }
+ if (maxIndexI != -1 && max > tauConnection) {
+ gt0Clusters.get(maxIndexI).mergeCluster(maxIndexJ);
+ if (debug)
+ System.out.println("Merging " + maxIndexI + " and " + maxIndexJ + " because of connection " + max);
+
+ changedConnection = true;
+ }
+ }
+ numGT0Classes = gt0Clusters.size();
+ }
+
+ /**
+ * Calculates how well the original clusters are separable. Small values
+ * indicate bad separability, values close to 1 indicate good separability
+ *
+ * @return index of separability
+ */
+ public double getClassSeparability() {
+ // int totalConn = numGTClasses*(numGTClasses-1)/2;
+ // int mergedConn = 0;
+ // for(GTCluster gt : gt0Clusters){
+ // int merged = gt.clusterRepresentations.size();
+ // if(merged > 1)
+ // mergedConn+=merged * (merged-1)/2;
+ // }
+ // if(totalConn == 0)
+ // return 0;
+ // else
+ // return 1-mergedConn/(double)totalConn;
+ return numGT0Classes / (double) numGTClasses;
+
+ }
+
+ /**
+ * Calculates how well noise is separable from the given clusters. Small values
+ * indicate bad separability, values close to 1 indicate good separability
+ *
+ * @return index of noise separability
+ */
+ public double getNoiseSeparability() {
+ if (noise.isEmpty())
+ return 1;
+
+ double connectivity = 0;
+ for (int p : noise) {
+ CMMPoint npoint = cmmpoints.get(p);
+ double maxConnection = 0;
+
+ // TODO: some kind of pruning possible. what about weighting?
+ for (int c = 0; c < gt0Clusters.size(); c++) {
+ double connection = getConnectionValue(npoint, c);
+ if (connection > maxConnection)
+ maxConnection = connection;
+ }
+ connectivity += maxConnection;
+ npoint.p.setMeasureValue("MaxConnection", maxConnection);
+ }
+
+ return 1 - (connectivity / noise.size());
+ }
+
+ /**
+ * Calculates the relative number of errors being caused by the underlying
+ * cluster model
+ *
+ * @return quality of the model
+ */
+ public double getModelQuality() {
+ for (int p = 0; p < numPoints; p++) {
+ CMMPoint cmdp = cmmpoints.get(p);
+ for (int hc = 0; hc < numGTClusters; hc++) {
+ if (gtClustering.get(hc).getGroundTruth() != cmdp.trueClass) {
+ if (gtClustering.get(hc).getInclusionProbability(cmdp) >= 1) {
+ if (!cmdp.isNoise())
+ pointErrorByModel++;
else
- conn = upperKnn/avgDist;
-
- if(Double.isNaN(conn))
- System.out.println("Connectivity NaN at "+cmmp.p.getTimestamp());
-
- return conn;
+ noiseErrorByModel++;
+ break;
+ }
}
+ }
}
+ if (debug)
+ System.out.println("Error by model: noise " + noiseErrorByModel + " point " + pointErrorByModel);
-
- /**
- * @param cmmp point to calculate knn distance for
- * @param k number of nearest neighbors to look for
- * @param pointIDs list of point IDs to check
- * @param knnDist sorted list of smallest knn distances (can already be filled to make updates possible)
- * @param knnPointIndex list of corresponding knn indices
- */
- private void getKnnInCluster(CMMPoint cmmp, int k,
- ArrayList<Integer> pointIDs,
- AutoExpandVector<Double> knnDist,
- AutoExpandVector<Integer> knnPointIndex) {
+ return 1 - ((pointErrorByModel + noiseErrorByModel) / (double) numPoints);
+ }
- //iterate over every point in the choosen cluster, cal distance and insert into list
- for (int p1 = 0; p1 < pointIDs.size(); p1++) {
- int pid = pointIDs.get(p1);
- if(cmmp.pID == pid) continue;
- double dist = distance(cmmp,cmmpoints.get(pid));
- if(knnDist.size() < k || dist < knnDist.get(knnDist.size()-1)){
- int index = 0;
- while(index < knnDist.size() && dist > knnDist.get(index)) {
- index++;
- }
- knnDist.add(index, dist);
- knnPointIndex.add(index,pid);
- if(knnDist.size() > k){
- knnDist.remove(knnDist.size()-1);
- knnPointIndex.remove(knnPointIndex.size()-1);
- }
- }
- }
+ /**
+ * Get CMM internal point
+ *
+ * @param index
+ * of the point
+ * @return cmm point
+ */
+ protected CMMPoint getPoint(int index) {
+ return cmmpoints.get(index);
+ }
+
+ /**
+ * Return cluster
+ *
+ * @param index
+ * of the cluster to return
+ * @return cluster
+ */
+ protected GTCluster getGT0Cluster(int index) {
+ return gt0Clusters.get(index);
+ }
+
+ /**
+ * Number of classes/clusters of the new clustering
+ *
+ * @return number of new clusters
+ */
+ protected int getNumberOfGT0Classes() {
+ return numGT0Classes;
+ }
+
+ /**
+ * Calculates Euclidean distance
+ *
+ * @param inst1
+ * point as double array
+ * @param inst2
+ * point as double array
+ * @return euclidean distance
+ */
+ private double distance(Instance inst1, Instance inst2) {
+ return distance(inst1, inst2.toDoubleArray());
+
+ }
+
+ /**
+ * Calculates Euclidean distance
+ *
+ * @param inst1
+ * point as an instance
+ * @param inst2
+ * point as double array
+ * @return euclidean distance
+ */
+ private double distance(Instance inst1, double[] inst2) {
+ double distance = 0.0;
+ for (int i = 0; i < numDims; i++) {
+ double d = inst1.value(i) - inst2[i];
+ distance += d * d;
}
+ return Math.sqrt(distance);
+ }
-
-
- /**
- * calculate initial connectivities
- */
- private void calculateGTPointQualities(){
- for (int p = 0; p < numPoints; p++) {
- CMMPoint cmdp = cmmpoints.get(p);
- if(!cmdp.isNoise()){
- cmdp.connectivity = getConnectionValue(cmdp, cmdp.workclass());
- cmdp.p.setMeasureValue("Connectivity", cmdp.connectivity);
- }
- }
+ /**
+ * String with main CMM parameters
+ *
+ * @return main CMM parameter
+ */
+ public String getParameterString() {
+ String para = "";
+ para += "k=" + knnNeighbourhood + ";";
+ if (useExpConnectivity) {
+ para += "lambdaConnX=" + lambdaConnX + ";";
+ para += "lambdaConn=" + lamdaConn + ";";
+ para += "lambdaConnRef=" + lambdaConnRefXValue + ";";
}
+ para += "m=" + clusterConnectionMaxPoints + ";";
+ para += "tauConn=" + tauConnection + ";";
-
-
- /**
- * Calculate connections between clusters and merge clusters accordingly as
- * long as connections exceed threshold
- */
- private void calculateGTClusterConnections(){
- for (int c0 = 0; c0 < gt0Clusters.size(); c0++) {
- for (int c1 = 0; c1 < gt0Clusters.size(); c1++) {
- gt0Clusters.get(c0).calculateClusterConnection(c1, true);
- }
- }
-
- boolean changedConnection = true;
- while(changedConnection){
- if(debug){
- System.out.println("Cluster Connection");
- for (int c = 0; c < gt0Clusters.size(); c++) {
- System.out.print("C"+gt0Clusters.get(c).label+" --> ");
- for (int c1 = 0; c1 < gt0Clusters.get(c).connections.size(); c1++) {
- System.out.print(" C"+gt0Clusters.get(c1).label+": "+gt0Clusters.get(c).connections.get(c1));
- }
- System.out.println("");
- }
- System.out.println("");
- }
-
- double max = 0;
- int maxIndexI = -1;
- int maxIndexJ = -1;
-
- changedConnection = false;
- for (int c0 = 0; c0 < gt0Clusters.size(); c0++) {
- for (int c1 = c0+1; c1 < gt0Clusters.size(); c1++) {
- if(c0==c1) continue;
- double min =Math.min(gt0Clusters.get(c0).connections.get(c1), gt0Clusters.get(c1).connections.get(c0));
- if(min > max){
- max = min;
- maxIndexI = c0;
- maxIndexJ = c1;
- }
- }
- }
- if(maxIndexI!=-1 && max > tauConnection){
- gt0Clusters.get(maxIndexI).mergeCluster(maxIndexJ);
- if(debug)
- System.out.println("Merging "+maxIndexI+" and "+maxIndexJ+" because of connection "+max);
-
- changedConnection = true;
- }
- }
- numGT0Classes = gt0Clusters.size();
- }
-
-
- /**
- * Calculates how well the original clusters are separable.
- * Small values indicate bad separability, values close to 1 indicate good separability
- * @return index of seperability
- */
- public double getClassSeparability(){
-// int totalConn = numGTClasses*(numGTClasses-1)/2;
-// int mergedConn = 0;
-// for(GTCluster gt : gt0Clusters){
-// int merged = gt.clusterRepresentations.size();
-// if(merged > 1)
-// mergedConn+=merged * (merged-1)/2;
-// }
-// if(totalConn == 0)
-// return 0;
-// else
-// return 1-mergedConn/(double)totalConn;
- return numGT0Classes/(double)numGTClasses;
-
- }
-
-
- /**
- * Calculates how well noise is separable from the given clusters
- * Small values indicate bad separability, values close to 1 indicate good separability
- * @return index of noise separability
- */
- public double getNoiseSeparability(){
- if(noise.isEmpty())
- return 1;
-
- double connectivity = 0;
- for(int p : noise){
- CMMPoint npoint = cmmpoints.get(p);
- double maxConnection = 0;
-
- //TODO: some kind of pruning possible. what about weighting?
- for (int c = 0; c < gt0Clusters.size(); c++) {
- double connection = getConnectionValue(npoint, c);
- if(connection > maxConnection)
- maxConnection = connection;
- }
- connectivity+=maxConnection;
- npoint.p.setMeasureValue("MaxConnection", maxConnection);
- }
-
- return 1-(connectivity / noise.size());
- }
-
-
- /**
- * Calculates the relative number of errors being caused by the underlying cluster model
- * @return quality of the model
- */
- public double getModelQuality(){
- for(int p = 0; p < numPoints; p++){
- CMMPoint cmdp = cmmpoints.get(p);
- for(int hc = 0; hc < numGTClusters;hc++){
- if(gtClustering.get(hc).getGroundTruth() != cmdp.trueClass){
- if(gtClustering.get(hc).getInclusionProbability(cmdp) >= 1){
- if(!cmdp.isNoise())
- pointErrorByModel++;
- else
- noiseErrorByModel++;
- break;
- }
- }
- }
- }
- if(debug)
- System.out.println("Error by model: noise "+noiseErrorByModel+" point "+pointErrorByModel);
-
- return 1-((pointErrorByModel + noiseErrorByModel)/(double) numPoints);
- }
-
-
- /**
- * Get CMM internal point
- * @param index of the point
- * @return cmm point
- */
- protected CMMPoint getPoint(int index){
- return cmmpoints.get(index);
- }
-
-
- /**
- * Return cluster
- * @param index of the cluster to return
- * @return cluster
- */
- protected GTCluster getGT0Cluster(int index){
- return gt0Clusters.get(index);
- }
-
- /**
- * Number of classes/clusters of the new clustering
- * @return number of new clusters
- */
- protected int getNumberOfGT0Classes() {
- return numGT0Classes;
- }
-
- /**
- * Calculates Euclidian distance
- * @param inst1 point as double array
- * @param inst2 point as double array
- * @return euclidian distance
- */
- private double distance(Instance inst1, Instance inst2){
- return distance(inst1, inst2.toDoubleArray());
-
- }
-
- /**
- * Calculates Euclidian distance
- * @param inst1 point as an instance
- * @param inst2 point as double array
- * @return euclidian distance
- */
- private double distance(Instance inst1, double[] inst2){
- double distance = 0.0;
- for (int i = 0; i < numDims; i++) {
- double d = inst1.value(i) - inst2[i];
- distance += d * d;
- }
- return Math.sqrt(distance);
- }
-
- /**
- * String with main CMM parameters
- * @return main CMM parameter
- */
- public String getParameterString(){
- String para = "";
- para+="k="+knnNeighbourhood+";";
- if(useExpConnectivity){
- para+="lambdaConnX="+lambdaConnX+";";
- para+="lambdaConn="+lamdaConn+";";
- para+="lambdaConnRef="+lambdaConnRefXValue+";";
- }
- para+="m="+clusterConnectionMaxPoints+";";
- para+="tauConn="+tauConnection+";";
-
- return para;
- }
+ return para;
+ }
}
-
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/EntropyCollection.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/EntropyCollection.java
index 0d311e4..1a44542 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/EntropyCollection.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/EntropyCollection.java
@@ -30,145 +30,146 @@
import com.yahoo.labs.samoa.moa.evaluation.MeasureCollection;
import com.yahoo.labs.samoa.moa.evaluation.MembershipMatrix;
-public class EntropyCollection extends MeasureCollection{
+public class EntropyCollection extends MeasureCollection {
- private static final Logger logger = LoggerFactory.getLogger(EntropyCollection.class);
+ private static final Logger logger = LoggerFactory.getLogger(EntropyCollection.class);
- @Override
- protected String[] getNames() {
- return new String[]{"GT cross entropy","FC cross entropy","Homogeneity","Completeness","V-Measure","VarInformation"};
+ @Override
+ protected String[] getNames() {
+ return new String[] { "GT cross entropy", "FC cross entropy", "Homogeneity", "Completeness", "V-Measure",
+ "VarInformation" };
+ }
+
+ @Override
+ protected boolean[] getDefaultEnabled() {
+ return new boolean[] { false, false, false, false, false, false };
+ }
+
+ @Override
+ public void evaluateClustering(Clustering fclustering, Clustering hClustering, ArrayList<DataPoint> points)
+ throws Exception {
+
+ MembershipMatrix mm = new MembershipMatrix(fclustering, points);
+ int numClasses = mm.getNumClasses();
+ int numCluster = fclustering.size() + 1;
+ int n = mm.getTotalEntries();
+
+ double FCentropy = 0;
+ if (numCluster > 1) {
+ for (int fc = 0; fc < numCluster; fc++) {
+ double weight = mm.getClusterSum(fc) / (double) n;
+ if (weight > 0)
+ FCentropy += weight * Math.log10(weight);
+ }
+ FCentropy /= (-1 * Math.log10(numCluster));
}
- @Override
- protected boolean[] getDefaultEnabled() {
- return new boolean[]{false, false, false, false, false, false};
+ logger.debug("FC entropy: {}", FCentropy);
+
+ double GTentropy = 0;
+ if (numClasses > 1) {
+ for (int hc = 0; hc < numClasses; hc++) {
+ double weight = mm.getClassSum(hc) / (double) n;
+ if (weight > 0)
+ GTentropy += weight * Math.log10(weight);
+ }
+ GTentropy /= (-1 * Math.log10(numClasses));
}
- @Override
- public void evaluateClustering(Clustering fclustering, Clustering hClustering, ArrayList<DataPoint> points) throws Exception {
+ logger.debug("GT entropy: {}", GTentropy);
- MembershipMatrix mm = new MembershipMatrix(fclustering, points);
- int numClasses = mm.getNumClasses();
- int numCluster = fclustering.size()+1;
- int n = mm.getTotalEntries();
+ // cluster based entropy
+ double FCcrossEntropy = 0;
-
- double FCentropy = 0;
- if(numCluster > 1){
- for (int fc = 0; fc < numCluster; fc++){
- double weight = mm.getClusterSum(fc)/(double)n;
- if(weight > 0)
- FCentropy+= weight * Math.log10(weight);
- }
- FCentropy/=(-1*Math.log10(numCluster));
+ for (int fc = 0; fc < numCluster; fc++) {
+ double e = 0;
+ int clusterWeight = mm.getClusterSum(fc);
+ if (clusterWeight > 0) {
+ for (int hc = 0; hc < numClasses; hc++) {
+ double p = mm.getClusterClassWeight(fc, hc) / (double) clusterWeight;
+ if (p != 0) {
+ e += p * Math.log10(p);
+ }
}
-
- logger.debug("FC entropy: {}", FCentropy);
-
- double GTentropy = 0;
- if(numClasses > 1){
- for (int hc = 0; hc < numClasses; hc++){
- double weight = mm.getClassSum(hc)/(double)n;
- if(weight > 0)
- GTentropy+= weight * Math.log10(weight);
- }
- GTentropy/=(-1*Math.log10(numClasses));
- }
-
- logger.debug("GT entropy: {}", GTentropy);
-
- //cluster based entropy
- double FCcrossEntropy = 0;
-
- for (int fc = 0; fc < numCluster; fc++){
- double e = 0;
- int clusterWeight = mm.getClusterSum(fc);
- if(clusterWeight>0){
- for (int hc = 0; hc < numClasses; hc++) {
- double p = mm.getClusterClassWeight(fc, hc)/(double)clusterWeight;
- if(p!=0){
- e+=p * Math.log10(p);
- }
- }
- FCcrossEntropy+=((clusterWeight/(double)n) * e);
- }
- }
- if(numCluster > 1){
- FCcrossEntropy/=-1*Math.log10(numCluster);
- }
-
- addValue("FC cross entropy", 1-FCcrossEntropy);
- logger.debug("FC cross entropy: {}", 1 - FCcrossEntropy);
-
- //class based entropy
- double GTcrossEntropy = 0;
- for (int hc = 0; hc < numClasses; hc++){
- double e = 0;
- int classWeight = mm.getClassSum(hc);
- if(classWeight>0){
- for (int fc = 0; fc < numCluster; fc++) {
- double p = mm.getClusterClassWeight(fc, hc)/(double)classWeight;
- if(p!=0){
- e+=p * Math.log10(p);
- }
- }
- }
- GTcrossEntropy+=((classWeight/(double)n) * e);
- }
- if(numClasses > 1)
- GTcrossEntropy/=-1*Math.log10(numClasses);
- addValue("GT cross entropy", 1-GTcrossEntropy);
- logger.debug("GT cross entropy: {}", 1 - GTcrossEntropy);
-
- double homogeneity;
- if(FCentropy == 0)
- homogeneity = 1;
- else
- homogeneity = 1 - FCcrossEntropy/FCentropy;
-
- //TODO set err values for now, needs to be debugged
- if(homogeneity > 1 || homogeneity < 0)
- addValue("Homogeneity",-1);
- else
- addValue("Homogeneity",homogeneity);
-
- double completeness;
- if(GTentropy == 0)
- completeness = 1;
- else
- completeness = 1 - GTcrossEntropy/GTentropy;
- addValue("Completeness",completeness);
-
- double beta = 1;
- double vmeasure = (1+ beta)*homogeneity*completeness/(beta *homogeneity+completeness);
-
- if(vmeasure > 1 || homogeneity < 0)
- addValue("V-Measure",-1);
- else
- addValue("V-Measure",vmeasure);
-
-
-
- double mutual = 0;
- for (int i = 0; i < numCluster; i++){
- for (int j = 0; j < numClasses; j++) {
- if(mm.getClusterClassWeight(i, j)==0) continue;
- double m = Math.log10(mm.getClusterClassWeight(i, j)/(double)mm.getClusterSum(i)/(double)mm.getClassSum(j)*(double)n);
- m*= mm.getClusterClassWeight(i, j)/(double)n;
- logger.debug("( {} / {}): ",m, m);
- mutual+=m;
- }
- }
- if(numClasses > 1)
- mutual/=Math.log10(numClasses);
-
- double varInfo = 1;
- if(FCentropy + GTentropy > 0)
- varInfo = 2*mutual/(FCentropy + GTentropy);
-
- logger.debug("mutual: {} / VI: {}", mutual, varInfo);
- addValue("VarInformation", varInfo);
-
+ FCcrossEntropy += ((clusterWeight / (double) n) * e);
+ }
}
+ if (numCluster > 1) {
+ FCcrossEntropy /= -1 * Math.log10(numCluster);
+ }
+
+ addValue("FC cross entropy", 1 - FCcrossEntropy);
+ logger.debug("FC cross entropy: {}", 1 - FCcrossEntropy);
+
+ // class based entropy
+ double GTcrossEntropy = 0;
+ for (int hc = 0; hc < numClasses; hc++) {
+ double e = 0;
+ int classWeight = mm.getClassSum(hc);
+ if (classWeight > 0) {
+ for (int fc = 0; fc < numCluster; fc++) {
+ double p = mm.getClusterClassWeight(fc, hc) / (double) classWeight;
+ if (p != 0) {
+ e += p * Math.log10(p);
+ }
+ }
+ }
+ GTcrossEntropy += ((classWeight / (double) n) * e);
+ }
+ if (numClasses > 1)
+ GTcrossEntropy /= -1 * Math.log10(numClasses);
+ addValue("GT cross entropy", 1 - GTcrossEntropy);
+ logger.debug("GT cross entropy: {}", 1 - GTcrossEntropy);
+
+ double homogeneity;
+ if (FCentropy == 0)
+ homogeneity = 1;
+ else
+ homogeneity = 1 - FCcrossEntropy / FCentropy;
+
+ // TODO set err values for now, needs to be debugged
+ if (homogeneity > 1 || homogeneity < 0)
+ addValue("Homogeneity", -1);
+ else
+ addValue("Homogeneity", homogeneity);
+
+ double completeness;
+ if (GTentropy == 0)
+ completeness = 1;
+ else
+ completeness = 1 - GTcrossEntropy / GTentropy;
+ addValue("Completeness", completeness);
+
+ double beta = 1;
+ double vmeasure = (1 + beta) * homogeneity * completeness / (beta * homogeneity + completeness);
+
+ if (vmeasure > 1 || homogeneity < 0)
+ addValue("V-Measure", -1);
+ else
+ addValue("V-Measure", vmeasure);
+
+ double mutual = 0;
+ for (int i = 0; i < numCluster; i++) {
+ for (int j = 0; j < numClasses; j++) {
+ if (mm.getClusterClassWeight(i, j) == 0)
+ continue;
+ double m = Math.log10(mm.getClusterClassWeight(i, j) / (double) mm.getClusterSum(i)
+ / (double) mm.getClassSum(j) * (double) n);
+ m *= mm.getClusterClassWeight(i, j) / (double) n;
+ logger.debug("( {} / {}): ", m, m);
+ mutual += m;
+ }
+ }
+ if (numClasses > 1)
+ mutual /= Math.log10(numClasses);
+
+ double varInfo = 1;
+ if (FCentropy + GTentropy > 0)
+ varInfo = 2 * mutual / (FCentropy + GTentropy);
+
+ logger.debug("mutual: {} / VI: {}", mutual, varInfo);
+ addValue("VarInformation", varInfo);
+
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/F1.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/F1.java
index 6533f36..f62b6bb 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/F1.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/F1.java
@@ -26,90 +26,85 @@
import com.yahoo.labs.samoa.moa.core.DataPoint;
import java.util.ArrayList;
+public class F1 extends MeasureCollection {
-public class F1 extends MeasureCollection{
+ @Override
+ protected String[] getNames() {
+ return new String[] { "F1-P", "F1-R", "Purity" };
+ }
- @Override
- protected String[] getNames() {
- return new String[]{"F1-P","F1-R","Purity"};
+ public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) {
+
+ if (clustering.size() < 0) {
+ addValue(0, 0);
+ addValue(1, 0);
+ return;
}
- public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) {
+ MembershipMatrix mm = new MembershipMatrix(clustering, points);
+ // System.out.println(mm.toString());
- if (clustering.size()<0){
- addValue(0,0);
- addValue(1,0);
- return;
+ int numClasses = mm.getNumClasses();
+ if (mm.hasNoiseClass())
+ numClasses--;
+
+ // F1 as defined in P3C, try using F1 optimization
+ double F1_P = 0.0;
+ double purity = 0;
+ int realClusters = 0;
+ for (int i = 0; i < clustering.size(); i++) {
+ int max_weight = 0;
+ int max_weight_index = -1;
+
+ // find max index
+ for (int j = 0; j < numClasses; j++) {
+ if (mm.getClusterClassWeight(i, j) > max_weight) {
+ max_weight = mm.getClusterClassWeight(i, j);
+ max_weight_index = j;
}
-
- MembershipMatrix mm = new MembershipMatrix(clustering, points);
- //System.out.println(mm.toString());
-
- int numClasses = mm.getNumClasses();
- if(mm.hasNoiseClass())
- numClasses--;
-
-
-
- //F1 as defined in P3C, try using F1 optimization
- double F1_P = 0.0;
- double purity = 0;
- int realClusters = 0;
- for (int i = 0; i < clustering.size(); i++) {
- int max_weight = 0;
- int max_weight_index = -1;
-
- //find max index
- for (int j = 0; j < numClasses; j++) {
- if(mm.getClusterClassWeight(i, j) > max_weight){
- max_weight = mm.getClusterClassWeight(i, j);
- max_weight_index = j;
- }
- }
- if(max_weight_index!=-1){
- realClusters++;
- double precision = mm.getClusterClassWeight(i, max_weight_index)/(double)mm.getClusterSum(i);
- double recall = mm.getClusterClassWeight(i, max_weight_index)/(double) mm.getClassSum(max_weight_index);
- double f1 = 0;
- if(precision > 0 || recall > 0){
- f1 = 2*precision*recall/(precision+recall);
- }
- F1_P += f1;
- purity += precision;
-
- //TODO should we move setMeasure stuff into the Cluster interface?
- clustering.get(i).setMeasureValue("F1-P", Double.toString(f1));
- }
+ }
+ if (max_weight_index != -1) {
+ realClusters++;
+ double precision = mm.getClusterClassWeight(i, max_weight_index) / (double) mm.getClusterSum(i);
+ double recall = mm.getClusterClassWeight(i, max_weight_index) / (double) mm.getClassSum(max_weight_index);
+ double f1 = 0;
+ if (precision > 0 || recall > 0) {
+ f1 = 2 * precision * recall / (precision + recall);
}
- if(realClusters > 0){
- F1_P/=realClusters;
- purity/=realClusters;
- }
- addValue("F1-P",F1_P);
- addValue("Purity",purity);
+ F1_P += f1;
+ purity += precision;
-
-
- //F1 as defined in .... mainly maximizes F1 for each class
- double F1_R = 0.0;
- for (int j = 0; j < numClasses; j++) {
- double max_f1 = 0;
- for (int i = 0; i < clustering.size(); i++) {
- double precision = mm.getClusterClassWeight(i, j)/(double)mm.getClusterSum(i);
- double recall = mm.getClusterClassWeight(i, j)/(double)mm.getClassSum(j);
- double f1 = 0;
- if(precision > 0 || recall > 0){
- f1 = 2*precision*recall/(precision+recall);
- }
- if(max_f1 < f1){
- max_f1 = f1;
- }
- }
- F1_R+= max_f1;
- }
- F1_R/=numClasses;
-
- addValue("F1-R",F1_R);
+ // TODO should we move setMeasure stuff into the Cluster interface?
+ clustering.get(i).setMeasureValue("F1-P", Double.toString(f1));
+ }
}
+ if (realClusters > 0) {
+ F1_P /= realClusters;
+ purity /= realClusters;
+ }
+ addValue("F1-P", F1_P);
+ addValue("Purity", purity);
+
+ // F1 as defined in .... mainly maximizes F1 for each class
+ double F1_R = 0.0;
+ for (int j = 0; j < numClasses; j++) {
+ double max_f1 = 0;
+ for (int i = 0; i < clustering.size(); i++) {
+ double precision = mm.getClusterClassWeight(i, j) / (double) mm.getClusterSum(i);
+ double recall = mm.getClusterClassWeight(i, j) / (double) mm.getClassSum(j);
+ double f1 = 0;
+ if (precision > 0 || recall > 0) {
+ f1 = 2 * precision * recall / (precision + recall);
+ }
+ if (max_f1 < f1) {
+ max_f1 = f1;
+ }
+ }
+ F1_R += max_f1;
+ }
+ F1_R /= numClasses;
+
+ addValue("F1-R", F1_R);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/General.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/General.java
index 7f23c1b..287af06 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/General.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/General.java
@@ -20,7 +20,6 @@
* #L%
*/
-
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.moa.cluster.Clustering;
import com.yahoo.labs.samoa.moa.cluster.SphereCluster;
@@ -28,164 +27,166 @@
import com.yahoo.labs.samoa.moa.core.DataPoint;
import java.util.ArrayList;
+public class General extends MeasureCollection {
+ private int numPoints;
+ private int numFClusters;
+ private int numDims;
+ private double pointInclusionProbThreshold = 0.8;
+ private Clustering clustering;
+ private ArrayList<DataPoint> points;
-public class General extends MeasureCollection{
- private int numPoints;
- private int numFClusters;
- private int numDims;
- private double pointInclusionProbThreshold = 0.8;
- private Clustering clustering;
- private ArrayList<DataPoint> points;
+ public General() {
+ super();
+ }
+ @Override
+ protected String[] getNames() {
+ // String[] names =
+ // {"GPrecision","GRecall","Redundancy","Overlap","numCluster","numClasses","Compactness"};
+ return new String[] { "GPrecision", "GRecall", "Redundancy", "numCluster", "numClasses" };
+ }
- public General() {
- super();
+ // @Override
+ // protected boolean[] getDefaultEnabled() {
+ // boolean [] defaults = {false, false, false, false, false ,false};
+ // return defaults;
+ // }
+
+ @Override
+ public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points)
+ throws Exception {
+
+ this.points = points;
+ this.clustering = clustering;
+ numPoints = points.size();
+ numFClusters = clustering.size();
+ numDims = points.get(0).numAttributes() - 1;
+
+ int totalRedundancy = 0;
+ int trueCoverage = 0;
+ int totalCoverage = 0;
+
+ int numNoise = 0;
+ for (int p = 0; p < numPoints; p++) {
+ int coverage = 0;
+ for (int c = 0; c < numFClusters; c++) {
+ // contained in cluster c?
+ if (clustering.get(c).getInclusionProbability(points.get(p)) >= pointInclusionProbThreshold) {
+ coverage++;
+ }
+ }
+
+ if (points.get(p).classValue() == -1) {
+ numNoise++;
+ }
+ else {
+ if (coverage > 0)
+ trueCoverage++;
+ }
+
+ if (coverage > 0)
+ totalCoverage++; // points covered by clustering (incl. noise)
+ if (coverage > 1)
+ totalRedundancy++; // include noise
}
+ addValue("numCluster", clustering.size());
+ addValue("numClasses", trueClustering.size());
+ addValue("Redundancy", ((double) totalRedundancy / (double) numPoints));
+ addValue("GPrecision", (totalCoverage == 0 ? 0 : ((double) trueCoverage / (double) (totalCoverage))));
+ addValue("GRecall", ((double) trueCoverage / (double) (numPoints - numNoise)));
+ // if(isEnabled(3)){
+ // addValue("Compactness", computeCompactness());
+ // }
+ // if(isEnabled(3)){
+ // addValue("Overlap", computeOverlap());
+ // }
+ }
- @Override
- protected String[] getNames() {
- //String[] names = {"GPrecision","GRecall","Redundancy","Overlap","numCluster","numClasses","Compactness"};
- return new String[]{"GPrecision","GRecall","Redundancy","numCluster","numClasses"};
+ private double computeOverlap() {
+ for (int c = 0; c < numFClusters; c++) {
+ if (!(clustering.get(c) instanceof SphereCluster)) {
+ System.out.println("Overlap only supports Sphere Cluster. Found: " + clustering.get(c).getClass());
+ return Double.NaN;
+ }
}
-// @Override
-// protected boolean[] getDefaultEnabled() {
-// boolean [] defaults = {false, false, false, false, false ,false};
-// return defaults;
-// }
+ boolean[] overlap = new boolean[numFClusters];
- @Override
- public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) throws Exception{
-
- this.points = points;
- this.clustering = clustering;
- numPoints = points.size();
- numFClusters = clustering.size();
- numDims = points.get(0).numAttributes()-1;
-
-
- int totalRedundancy = 0;
- int trueCoverage = 0;
- int totalCoverage = 0;
-
- int numNoise = 0;
- for (int p = 0; p < numPoints; p++) {
- int coverage = 0;
- for (int c = 0; c < numFClusters; c++) {
- //contained in cluster c?
- if(clustering.get(c).getInclusionProbability(points.get(p)) >= pointInclusionProbThreshold){
- coverage++;
- }
- }
-
- if(points.get(p).classValue()==-1){
- numNoise++;
- }
- else{
- if(coverage>0) trueCoverage++;
- }
-
- if(coverage>0) totalCoverage++; //points covered by clustering (incl. noise)
- if(coverage>1) totalRedundancy++; //include noise
+ for (int c0 = 0; c0 < numFClusters; c0++) {
+ if (overlap[c0])
+ continue;
+ SphereCluster s0 = (SphereCluster) clustering.get(c0);
+ for (int c1 = c0; c1 < clustering.size(); c1++) {
+ if (c1 == c0)
+ continue;
+ SphereCluster s1 = (SphereCluster) clustering.get(c1);
+ if (s0.overlapRadiusDegree(s1) > 0) {
+ overlap[c0] = overlap[c1] = true;
}
-
- addValue("numCluster", clustering.size());
- addValue("numClasses", trueClustering.size());
- addValue("Redundancy", ((double)totalRedundancy/(double)numPoints));
- addValue("GPrecision", (totalCoverage==0?0:((double)trueCoverage/(double)(totalCoverage))));
- addValue("GRecall", ((double)trueCoverage/(double)(numPoints-numNoise)));
-// if(isEnabled(3)){
-// addValue("Compactness", computeCompactness());
-// }
-// if(isEnabled(3)){
-// addValue("Overlap", computeOverlap());
-// }
+ }
}
- private double computeOverlap(){
- for (int c = 0; c < numFClusters; c++) {
- if(!(clustering.get(c) instanceof SphereCluster)){
- System.out.println("Overlap only supports Sphere Cluster. Found: "+clustering.get(c).getClass());
- return Double.NaN;
- }
- }
-
- boolean[] overlap = new boolean[numFClusters];
-
- for (int c0 = 0; c0 < numFClusters; c0++) {
- if(overlap[c0]) continue;
- SphereCluster s0 = (SphereCluster)clustering.get(c0);
- for (int c1 = c0; c1 < clustering.size(); c1++) {
- if(c1 == c0) continue;
- SphereCluster s1 = (SphereCluster)clustering.get(c1);
- if(s0.overlapRadiusDegree(s1) > 0){
- overlap[c0] = overlap[c1] = true;
- }
- }
- }
-
- double totalOverlap = 0;
- for (int c0 = 0; c0 < numFClusters; c0++) {
- if(overlap[c0])
- totalOverlap++;
- }
-
-// if(totalOverlap/(double)numFClusters > .8) RunVisualizer.pause();
- if(numFClusters>0) totalOverlap/=(double)numFClusters;
- return totalOverlap;
+ double totalOverlap = 0;
+ for (int c0 = 0; c0 < numFClusters; c0++) {
+ if (overlap[c0])
+ totalOverlap++;
}
+ // if(totalOverlap/(double)numFClusters > .8) RunVisualizer.pause();
+ if (numFClusters > 0)
+ totalOverlap /= (double) numFClusters;
+ return totalOverlap;
+ }
- private double computeCompactness(){
- if(numFClusters == 0) return 0;
- for (int c = 0; c < numFClusters; c++) {
- if(!(clustering.get(c) instanceof SphereCluster)){
- System.out.println("Compactness only supports Sphere Cluster. Found: "+clustering.get(c).getClass());
- return Double.NaN;
- }
- }
-
- //TODO weight radius by number of dimensions
- double totalCompactness = 0;
- for (int c = 0; c < numFClusters; c++) {
- ArrayList<Instance> containedPoints = new ArrayList<Instance>();
- for (int p = 0; p < numPoints; p++) {
- //p in c
- if(clustering.get(c).getInclusionProbability(points.get(p)) >= pointInclusionProbThreshold){
- containedPoints.add(points.get(p));
- }
- }
- double compactness = 0;
- if(containedPoints.size()>1){
- //cluster not empty
- SphereCluster minEnclosingCluster = new SphereCluster(containedPoints, numDims);
- double minRadius = minEnclosingCluster.getRadius();
- double cfRadius = ((SphereCluster)clustering.get(c)).getRadius();
- if(Math.abs(minRadius-cfRadius) < 0.1e-10){
- compactness = 1;
- }
- else
- if(minRadius < cfRadius)
- compactness = minRadius/cfRadius;
- else{
- System.out.println("Optimal radius bigger then real one ("+(cfRadius-minRadius)+"), this is really wrong");
- compactness = 1;
- }
- }
- else{
- double cfRadius = ((SphereCluster)clustering.get(c)).getRadius();
- if(cfRadius==0) compactness = 1;
- }
-
- //weight by weight of cluster???
- totalCompactness+=compactness;
- clustering.get(c).setMeasureValue("Compactness", Double.toString(compactness));
- }
- return (totalCompactness/numFClusters);
+ private double computeCompactness() {
+ if (numFClusters == 0)
+ return 0;
+ for (int c = 0; c < numFClusters; c++) {
+ if (!(clustering.get(c) instanceof SphereCluster)) {
+ System.out.println("Compactness only supports Sphere Cluster. Found: " + clustering.get(c).getClass());
+ return Double.NaN;
+ }
}
+ // TODO weight radius by number of dimensions
+ double totalCompactness = 0;
+ for (int c = 0; c < numFClusters; c++) {
+ ArrayList<Instance> containedPoints = new ArrayList<Instance>();
+ for (int p = 0; p < numPoints; p++) {
+ // p in c
+ if (clustering.get(c).getInclusionProbability(points.get(p)) >= pointInclusionProbThreshold) {
+ containedPoints.add(points.get(p));
+ }
+ }
+ double compactness = 0;
+ if (containedPoints.size() > 1) {
+ // cluster not empty
+ SphereCluster minEnclosingCluster = new SphereCluster(containedPoints, numDims);
+ double minRadius = minEnclosingCluster.getRadius();
+ double cfRadius = ((SphereCluster) clustering.get(c)).getRadius();
+ if (Math.abs(minRadius - cfRadius) < 0.1e-10) {
+ compactness = 1;
+ }
+ else if (minRadius < cfRadius)
+ compactness = minRadius / cfRadius;
+ else {
+ System.out.println("Optimal radius bigger then real one (" + (cfRadius - minRadius)
+ + "), this is really wrong");
+ compactness = 1;
+ }
+ }
+ else {
+ double cfRadius = ((SphereCluster) clustering.get(c)).getRadius();
+ if (cfRadius == 0)
+ compactness = 1;
+ }
+
+ // weight by weight of cluster???
+ totalCompactness += compactness;
+ clustering.get(c).setMeasureValue("Compactness", Double.toString(compactness));
+ }
+ return (totalCompactness / numFClusters);
+ }
}
-
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SSQ.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SSQ.java
index 4f57788..ac25888 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SSQ.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SSQ.java
@@ -28,69 +28,70 @@
public class SSQ extends MeasureCollection {
- public SSQ() {
- super();
- }
+ public SSQ() {
+ super();
+ }
- @Override
- public String[] getNames() {
- return new String[]{"SSQ"};
- }
+ @Override
+ public String[] getNames() {
+ return new String[] { "SSQ" };
+ }
- @Override
- protected boolean[] getDefaultEnabled() {
- return new boolean[]{false};
- }
+ @Override
+ protected boolean[] getDefaultEnabled() {
+ return new boolean[] { false };
+ }
- // TODO Work on this later
- //@Override
- public void evaluateClusteringSamoa(Clustering clustering,
- Clustering trueClsutering, ArrayList<Instance> points) {
- double sum = 0.0;
- for (Instance point : points) {
- // don't include noise
- if (point.classValue() == -1) {
- continue;
- }
+ // TODO Work on this later
+ // @Override
+ public void evaluateClusteringSamoa(Clustering clustering,
+ Clustering trueClsutering, ArrayList<Instance> points) {
+ double sum = 0.0;
+ for (Instance point : points) {
+ // don't include noise
+ if (point.classValue() == -1) {
+ continue;
+ }
- double minDistance = Double.MAX_VALUE;
- for (int c = 0; c < clustering.size(); c++) {
- double distance = 0.0;
- double[] center = clustering.get(c).getCenter();
- for (int i = 0; i < center.length; i++) {
- double d = point.value(i) - center[i];
- distance += d * d;
- }
- minDistance = Math.min(distance, minDistance);
- }
-
- sum += minDistance;
+ double minDistance = Double.MAX_VALUE;
+ for (int c = 0; c < clustering.size(); c++) {
+ double distance = 0.0;
+ double[] center = clustering.get(c).getCenter();
+ for (int i = 0; i < center.length; i++) {
+ double d = point.value(i) - center[i];
+ distance += d * d;
}
+ minDistance = Math.min(distance, minDistance);
+ }
- addValue(0, sum);
+ sum += minDistance;
}
- @Override
- public void evaluateClustering(Clustering clustering, Clustering trueClsutering, ArrayList<DataPoint> points) {
- double sum = 0.0;
- for (int p = 0; p < points.size(); p++) {
- //don't include noise
- if(points.get(p).classValue()==-1) continue;
+ addValue(0, sum);
+ }
- double minDistance = Double.MAX_VALUE;
- for (int c = 0; c < clustering.size(); c++) {
- double distance = 0.0;
- double[] center = clustering.get(c).getCenter();
- for (int i = 0; i < center.length; i++) {
- double d = points.get(p).value(i) - center[i];
- distance += d * d;
- }
- minDistance = Math.min(distance, minDistance);
- }
-
- sum+=minDistance;
+ @Override
+ public void evaluateClustering(Clustering clustering, Clustering trueClsutering, ArrayList<DataPoint> points) {
+ double sum = 0.0;
+ for (int p = 0; p < points.size(); p++) {
+ // don't include noise
+ if (points.get(p).classValue() == -1)
+ continue;
+
+ double minDistance = Double.MAX_VALUE;
+ for (int c = 0; c < clustering.size(); c++) {
+ double distance = 0.0;
+ double[] center = clustering.get(c).getCenter();
+ for (int i = 0; i < center.length; i++) {
+ double d = points.get(p).value(i) - center[i];
+ distance += d * d;
}
-
- addValue(0,sum);
+ minDistance = Math.min(distance, minDistance);
+ }
+
+ sum += minDistance;
}
+
+ addValue(0, sum);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/Separation.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/Separation.java
index 25534a6..1e3072b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/Separation.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/Separation.java
@@ -31,88 +31,88 @@
public class Separation extends MeasureCollection {
- public Separation() {
- super();
+ public Separation() {
+ super();
+ }
+
+ @Override
+ protected String[] getNames() {
+ return new String[] { "BSS", "BSS-GT", "BSS-Ratio" };
+ }
+
+ // @Override
+ public void evaluateClusteringSamoa(Clustering clustering,
+ Clustering trueClustering, ArrayList<Instance> points)
+ throws Exception {
+
+ double BSS_GT = 1.0;
+ double BSS;
+ int dimension = points.get(0).numAttributes() - 1;
+ SphereCluster sc = new SphereCluster(points, dimension);
+
+ // DO INTERNAL EVALUATION
+ // clustering.getClustering().get(0).getCenter();
+
+ BSS = getBSS(clustering, sc.getCenter());
+
+ if (trueClustering != null) {
+ List<Instance> listInstances = new ArrayList<>();
+ for (Cluster c : trueClustering.getClustering()) {
+ DenseInstance inst = new DenseInstance(c.getWeight(), c.getCenter());
+ listInstances.add(inst);
+ }
+ SphereCluster gt = new SphereCluster(listInstances, dimension);
+ BSS_GT = getBSS(trueClustering, gt.getCenter());
}
- @Override
- protected String[] getNames() {
- return new String[]{"BSS", "BSS-GT", "BSS-Ratio"};
+ addValue("BSS", BSS);
+ addValue("BSS-GT", BSS_GT);
+ addValue("BSS-Ratio", BSS / BSS_GT);
+
+ }
+
+ private double getBSS(Clustering clustering, double[] mean) {
+ double bss = 0.0;
+ for (int i = 0; i < clustering.size(); i++) {
+ double weight = clustering.get(i).getWeight();
+ double sum = 0.0;
+ for (int j = 0; j < mean.length; j++) {
+ sum += Math.pow((mean[j] - clustering.get(i).getCenter()[j]), 2);
+ }
+ bss += weight * sum;
}
- //@Override
- public void evaluateClusteringSamoa(Clustering clustering,
- Clustering trueClustering, ArrayList<Instance> points)
- throws Exception {
+ return bss;
+ }
- double BSS_GT = 1.0;
- double BSS;
- int dimension = points.get(0).numAttributes() - 1;
- SphereCluster sc = new SphereCluster(points, dimension);
+ @Override
+ protected void evaluateClustering(Clustering clustering,
+ Clustering trueClustering, ArrayList<DataPoint> points)
+ throws Exception {
+ double BSS_GT = 1.0;
+ double BSS;
+ int dimension = points.get(0).numAttributes() - 1;
+ SphereCluster sc = new SphereCluster(points, dimension);
- // DO INTERNAL EVALUATION
- //clustering.getClustering().get(0).getCenter();
+ // DO INTERNAL EVALUATION
+ // clustering.getClustering().get(0).getCenter();
- BSS = getBSS(clustering, sc.getCenter());
+ BSS = getBSS(clustering, sc.getCenter());
- if (trueClustering != null) {
- List<Instance> listInstances = new ArrayList<>();
- for (Cluster c : trueClustering.getClustering()) {
- DenseInstance inst = new DenseInstance(c.getWeight(), c.getCenter());
- listInstances.add(inst);
- }
- SphereCluster gt = new SphereCluster(listInstances, dimension);
- BSS_GT = getBSS(trueClustering, gt.getCenter());
- }
-
- addValue("BSS", BSS);
- addValue("BSS-GT", BSS_GT);
- addValue("BSS-Ratio", BSS / BSS_GT);
-
+ if (trueClustering != null) {
+ String s = "";
+ List<Instance> listInstances = new ArrayList<>();
+ for (Cluster c : trueClustering.getClustering()) {
+ DenseInstance inst = new DenseInstance(c.getWeight(), c.getCenter());
+ listInstances.add(inst);
+ s += " " + c.getWeight();
+ }
+ SphereCluster gt = new SphereCluster(listInstances, dimension);
+ BSS_GT = getBSS(trueClustering, gt.getCenter());
}
- private double getBSS(Clustering clustering, double[] mean) {
- double bss = 0.0;
- for (int i = 0; i < clustering.size(); i++) {
- double weight = clustering.get(i).getWeight();
- double sum = 0.0;
- for (int j = 0; j < mean.length; j++) {
- sum += Math.pow((mean[j] - clustering.get(i).getCenter()[j]), 2);
- }
- bss += weight * sum;
- }
-
- return bss;
- }
-
- @Override
- protected void evaluateClustering(Clustering clustering,
- Clustering trueClustering, ArrayList<DataPoint> points)
- throws Exception {
- double BSS_GT = 1.0;
- double BSS;
- int dimension = points.get(0).numAttributes() - 1;
- SphereCluster sc = new SphereCluster(points, dimension);
-
- // DO INTERNAL EVALUATION
- //clustering.getClustering().get(0).getCenter();
-
- BSS = getBSS(clustering, sc.getCenter());
-
- if (trueClustering != null) {
- String s = "";
- List<Instance> listInstances = new ArrayList<>();
- for (Cluster c : trueClustering.getClustering()) {
- DenseInstance inst = new DenseInstance(c.getWeight(), c.getCenter());
- listInstances.add(inst);
- s += " " + c.getWeight();
- }
- SphereCluster gt = new SphereCluster(listInstances, dimension);
- BSS_GT = getBSS(trueClustering, gt.getCenter());
- }
-
- addValue("BSS", BSS);
- addValue("BSS-GT", BSS_GT);
- addValue("BSS-Ratio", BSS / BSS_GT);
- }
+ addValue("BSS", BSS);
+ addValue("BSS-GT", BSS_GT);
+ addValue("BSS-Ratio", BSS / BSS_GT);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SilhouetteCoefficient.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SilhouetteCoefficient.java
index 6dee336..3740910 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SilhouetteCoefficient.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/SilhouetteCoefficient.java
@@ -20,118 +20,116 @@
* #L%
*/
-
import com.yahoo.labs.samoa.moa.cluster.Cluster;
import com.yahoo.labs.samoa.moa.cluster.Clustering;
import com.yahoo.labs.samoa.moa.evaluation.MeasureCollection;
import com.yahoo.labs.samoa.moa.core.DataPoint;
import java.util.ArrayList;
+public class SilhouetteCoefficient extends MeasureCollection {
+ private double pointInclusionProbThreshold = 0.8;
-public class SilhouetteCoefficient extends MeasureCollection{
- private double pointInclusionProbThreshold = 0.8;
+ public SilhouetteCoefficient() {
+ super();
+ }
- public SilhouetteCoefficient() {
- super();
+ @Override
+ protected boolean[] getDefaultEnabled() {
+ return new boolean[] { false };
+ }
+
+ @Override
+ public String[] getNames() {
+ return new String[] { "SilhCoeff" };
+ }
+
+ public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) {
+ int numFCluster = clustering.size();
+
+ double[][] pointInclusionProbFC = new double[points.size()][numFCluster];
+ for (int p = 0; p < points.size(); p++) {
+ DataPoint point = points.get(p);
+ for (int fc = 0; fc < numFCluster; fc++) {
+ Cluster cl = clustering.get(fc);
+ pointInclusionProbFC[p][fc] = cl.getInclusionProbability(point);
+ }
}
- @Override
- protected boolean[] getDefaultEnabled() {
- return new boolean[]{false};
- }
+ double silhCoeff = 0.0;
+ int totalCount = 0;
+ for (int p = 0; p < points.size(); p++) {
+ DataPoint point = points.get(p);
+ ArrayList<Integer> ownClusters = new ArrayList<>();
+ for (int fc = 0; fc < numFCluster; fc++) {
+ if (pointInclusionProbFC[p][fc] > pointInclusionProbThreshold) {
+ ownClusters.add(fc);
+ }
+ }
- @Override
- public String[] getNames() {
- return new String[]{"SilhCoeff"};
- }
-
- public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) {
- int numFCluster = clustering.size();
-
- double [][] pointInclusionProbFC = new double[points.size()][numFCluster];
- for (int p = 0; p < points.size(); p++) {
- DataPoint point = points.get(p);
+ if (ownClusters.size() > 0) {
+ double[] distanceByClusters = new double[numFCluster];
+ int[] countsByClusters = new int[numFCluster];
+ // calculate averageDistance of p to all cluster
+ for (int p1 = 0; p1 < points.size(); p1++) {
+ DataPoint point1 = points.get(p1);
+ if (p1 != p && point1.classValue() != -1) {
for (int fc = 0; fc < numFCluster; fc++) {
- Cluster cl = clustering.get(fc);
- pointInclusionProbFC[p][fc] = cl.getInclusionProbability(point);
+ if (pointInclusionProbFC[p1][fc] > pointInclusionProbThreshold) {
+ double distance = distance(point, point1);
+ distanceByClusters[fc] += distance;
+ countsByClusters[fc]++;
+ }
}
+ }
}
- double silhCoeff = 0.0;
- int totalCount = 0;
- for (int p = 0; p < points.size(); p++) {
- DataPoint point = points.get(p);
- ArrayList<Integer> ownClusters = new ArrayList<>();
- for (int fc = 0; fc < numFCluster; fc++) {
- if(pointInclusionProbFC[p][fc] > pointInclusionProbThreshold){
- ownClusters.add(fc);
- }
- }
-
- if(ownClusters.size() > 0){
- double[] distanceByClusters = new double[numFCluster];
- int[] countsByClusters = new int[numFCluster];
- //calculate averageDistance of p to all cluster
- for (int p1 = 0; p1 < points.size(); p1++) {
- DataPoint point1 = points.get(p1);
- if(p1!= p && point1.classValue() != -1){
- for (int fc = 0; fc < numFCluster; fc++) {
- if(pointInclusionProbFC[p1][fc] > pointInclusionProbThreshold){
- double distance = distance(point, point1);
- distanceByClusters[fc]+=distance;
- countsByClusters[fc]++;
- }
- }
- }
- }
-
- //find closest OWN cluster as clusters might overlap
- double minAvgDistanceOwn = Double.MAX_VALUE;
- int minOwnIndex = -1;
- for (int fc : ownClusters) {
- double normDist = distanceByClusters[fc]/(double)countsByClusters[fc];
- if(normDist < minAvgDistanceOwn){// && pointInclusionProbFC[p][fc] > pointInclusionProbThreshold){
- minAvgDistanceOwn = normDist;
- minOwnIndex = fc;
- }
- }
-
-
- //find closest other (or other own) cluster
- double minAvgDistanceOther = Double.MAX_VALUE;
- for (int fc = 0; fc < numFCluster; fc++) {
- if(fc != minOwnIndex){
- double normDist = distanceByClusters[fc]/(double)countsByClusters[fc];
- if(normDist < minAvgDistanceOther){
- minAvgDistanceOther = normDist;
- }
- }
- }
-
- double silhP = (minAvgDistanceOther-minAvgDistanceOwn)/Math.max(minAvgDistanceOther, minAvgDistanceOwn);
- point.setMeasureValue("SC - own", minAvgDistanceOwn);
- point.setMeasureValue("SC - other", minAvgDistanceOther);
- point.setMeasureValue("SC", silhP);
-
- silhCoeff+=silhP;
- totalCount++;
- //System.out.println(point.getTimestamp()+" Silh "+silhP+" / "+avgDistanceOwn+" "+minAvgDistanceOther+" (C"+minIndex+")");
- }
+ // find closest OWN cluster as clusters might overlap
+ double minAvgDistanceOwn = Double.MAX_VALUE;
+ int minOwnIndex = -1;
+ for (int fc : ownClusters) {
+ double normDist = distanceByClusters[fc] / (double) countsByClusters[fc];
+ if (normDist < minAvgDistanceOwn) {// && pointInclusionProbFC[p][fc] >
+ // pointInclusionProbThreshold){
+ minAvgDistanceOwn = normDist;
+ minOwnIndex = fc;
+ }
}
- if(totalCount>0)
- silhCoeff/=(double)totalCount;
- //normalize from -1, 1 to 0,1
- silhCoeff = (silhCoeff+1)/2.0;
- addValue(0,silhCoeff);
+
+ // find closest other (or other own) cluster
+ double minAvgDistanceOther = Double.MAX_VALUE;
+ for (int fc = 0; fc < numFCluster; fc++) {
+ if (fc != minOwnIndex) {
+ double normDist = distanceByClusters[fc] / (double) countsByClusters[fc];
+ if (normDist < minAvgDistanceOther) {
+ minAvgDistanceOther = normDist;
+ }
+ }
+ }
+
+ double silhP = (minAvgDistanceOther - minAvgDistanceOwn) / Math.max(minAvgDistanceOther, minAvgDistanceOwn);
+ point.setMeasureValue("SC - own", minAvgDistanceOwn);
+ point.setMeasureValue("SC - other", minAvgDistanceOther);
+ point.setMeasureValue("SC", silhP);
+
+ silhCoeff += silhP;
+ totalCount++;
+ // System.out.println(point.getTimestamp()+" Silh "+silhP+" / "+avgDistanceOwn+" "+minAvgDistanceOther+" (C"+minIndex+")");
+ }
}
+ if (totalCount > 0)
+ silhCoeff /= (double) totalCount;
+ // normalize from -1, 1 to 0,1
+ silhCoeff = (silhCoeff + 1) / 2.0;
+ addValue(0, silhCoeff);
+ }
- private double distance(DataPoint inst1, DataPoint inst2){
- double distance = 0.0;
- int numDims = inst1.numAttributes();
- for (int i = 0; i < numDims; i++) {
- double d = inst1.value(i) - inst2.value(i);
- distance += d * d;
- }
- return Math.sqrt(distance);
+ private double distance(DataPoint inst1, DataPoint inst2) {
+ double distance = 0.0;
+ int numDims = inst1.numAttributes();
+ for (int i = 0; i < numDims; i++) {
+ double d = inst1.value(i) - inst2.value(i);
+ distance += d * d;
}
+ return Math.sqrt(distance);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/StatisticalCollection.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/StatisticalCollection.java
index 6fc7adc..9b5f866 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/StatisticalCollection.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/evaluation/measures/StatisticalCollection.java
@@ -28,159 +28,160 @@
import com.yahoo.labs.samoa.moa.evaluation.MeasureCollection;
import com.yahoo.labs.samoa.moa.evaluation.MembershipMatrix;
-public class StatisticalCollection extends MeasureCollection{
- private boolean debug = false;
+public class StatisticalCollection extends MeasureCollection {
+ private boolean debug = false;
- @Override
- protected String[] getNames() {
- //String[] names = {"van Dongen","Rand statistic", "C Index"};
- return new String[]{"van Dongen","Rand statistic"};
+ @Override
+ protected String[] getNames() {
+ // String[] names = {"van Dongen","Rand statistic", "C Index"};
+ return new String[] { "van Dongen", "Rand statistic" };
+ }
+
+ @Override
+ protected boolean[] getDefaultEnabled() {
+ return new boolean[] { false, false };
+ }
+
+ @Override
+ public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points)
+ throws Exception {
+
+ MembershipMatrix mm = new MembershipMatrix(clustering, points);
+ int numClasses = mm.getNumClasses();
+ int numCluster = clustering.size() + 1;
+ int n = mm.getTotalEntries();
+
+ double dongenMaxFC = 0;
+ double dongenMaxSumFC = 0;
+ for (int i = 0; i < numCluster; i++) {
+ double max = 0;
+ for (int j = 0; j < numClasses; j++) {
+ if (mm.getClusterClassWeight(i, j) > max)
+ max = mm.getClusterClassWeight(i, j);
+ }
+ dongenMaxFC += max;
+ if (mm.getClusterSum(i) > dongenMaxSumFC)
+ dongenMaxSumFC = mm.getClusterSum(i);
}
- @Override
- protected boolean[] getDefaultEnabled() {
- return new boolean[]{false, false};
+ double dongenMaxHC = 0;
+ double dongenMaxSumHC = 0;
+ for (int j = 0; j < numClasses; j++) {
+ double max = 0;
+ for (int i = 0; i < numCluster; i++) {
+ if (mm.getClusterClassWeight(i, j) > max)
+ max = mm.getClusterClassWeight(i, j);
+ }
+ dongenMaxHC += max;
+ if (mm.getClassSum(j) > dongenMaxSumHC)
+ dongenMaxSumHC = mm.getClassSum(j);
}
- @Override
- public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) throws Exception {
+ double dongen = (dongenMaxFC + dongenMaxHC) / (2 * n);
+ // normalized dongen
+ // double dongen = 1-(2*n - dongenMaxFC - dongenMaxHC)/(2*n - dongenMaxSumFC
+ // - dongenMaxSumHC);
+ if (debug)
+ System.out.println("Dongen HC:" + dongenMaxHC + " FC:" + dongenMaxFC + " Total:" + dongen + " n " + n);
+ addValue("van Dongen", dongen);
- MembershipMatrix mm = new MembershipMatrix(clustering, points);
- int numClasses = mm.getNumClasses();
- int numCluster = clustering.size()+1;
- int n = mm.getTotalEntries();
-
- double dongenMaxFC = 0;
- double dongenMaxSumFC = 0;
- for (int i = 0; i < numCluster; i++){
- double max = 0;
- for (int j = 0; j < numClasses; j++) {
- if(mm.getClusterClassWeight(i, j)>max) max = mm.getClusterClassWeight(i, j);
- }
- dongenMaxFC+=max;
- if(mm.getClusterSum(i)>dongenMaxSumFC) dongenMaxSumFC = mm.getClusterSum(i);
- }
-
- double dongenMaxHC = 0;
- double dongenMaxSumHC = 0;
- for (int j = 0; j < numClasses; j++) {
- double max = 0;
- for (int i = 0; i < numCluster; i++){
- if(mm.getClusterClassWeight(i, j)>max) max = mm.getClusterClassWeight(i, j);
- }
- dongenMaxHC+=max;
- if(mm.getClassSum(j)>dongenMaxSumHC) dongenMaxSumHC = mm.getClassSum(j);
- }
-
- double dongen = (dongenMaxFC + dongenMaxHC)/(2*n);
- //normalized dongen
- //double dongen = 1-(2*n - dongenMaxFC - dongenMaxHC)/(2*n - dongenMaxSumFC - dongenMaxSumHC);
- if(debug)
- System.out.println("Dongen HC:"+dongenMaxHC+" FC:"+dongenMaxFC+" Total:"+dongen+" n "+n);
-
- addValue("van Dongen", dongen);
-
-
- //Rand index
- //http://www.cais.ntu.edu.sg/~qihe/menu4.html
- double m1 = 0;
- for (int j = 0; j < numClasses; j++) {
- double v = mm.getClassSum(j);
- m1+= v*(v-1)/2.0;
- }
- double m2 = 0;
- for (int i = 0; i < numCluster; i++){
- double v = mm.getClusterSum(i);
- m2+= v*(v-1)/2.0;
- }
-
- double m = 0;
- for (int i = 0; i < numCluster; i++){
- for (int j = 0; j < numClasses; j++) {
- double v = mm.getClusterClassWeight(i, j);
- m+= v*(v-1)/2.0;
- }
- }
- double M = n*(n-1)/2.0;
- double rand = (M - m1 - m2 +2*m)/M;
- //normalized rand
- //double rand = (m - m1*m2/M)/(m1/2.0 + m2/2.0 - m1*m2/M);
-
- addValue("Rand statistic", rand);
-
-
- //addValue("C Index",cindex(clustering, points));
+ // Rand index
+ // http://www.cais.ntu.edu.sg/~qihe/menu4.html
+ double m1 = 0;
+ for (int j = 0; j < numClasses; j++) {
+ double v = mm.getClassSum(j);
+ m1 += v * (v - 1) / 2.0;
+ }
+ double m2 = 0;
+ for (int i = 0; i < numCluster; i++) {
+ double v = mm.getClusterSum(i);
+ m2 += v * (v - 1) / 2.0;
}
+ double m = 0;
+ for (int i = 0; i < numCluster; i++) {
+ for (int j = 0; j < numClasses; j++) {
+ double v = mm.getClusterClassWeight(i, j);
+ m += v * (v - 1) / 2.0;
+ }
+ }
+ double M = n * (n - 1) / 2.0;
+ double rand = (M - m1 - m2 + 2 * m) / M;
+ // normalized rand
+ // double rand = (m - m1*m2/M)/(m1/2.0 + m2/2.0 - m1*m2/M);
+ addValue("Rand statistic", rand);
- public double cindex(Clustering clustering, ArrayList<DataPoint> points){
- int numClusters = clustering.size();
- double withinClustersDistance = 0;
- int numDistancesWithin = 0;
- double numDistances = 0;
+ // addValue("C Index",cindex(clustering, points));
+ }
- //double[] withinClusters = new double[numClusters];
- double[] minWithinClusters = new double[numClusters];
- double[] maxWithinClusters = new double[numClusters];
- ArrayList<Integer>[] pointsInClusters = new ArrayList[numClusters];
- for (int c = 0; c < numClusters; c++) {
- pointsInClusters[c] = new ArrayList<>();
- minWithinClusters[c] = Double.MAX_VALUE;
- maxWithinClusters[c] = Double.MIN_VALUE;
- }
+ public double cindex(Clustering clustering, ArrayList<DataPoint> points) {
+ int numClusters = clustering.size();
+ double withinClustersDistance = 0;
+ int numDistancesWithin = 0;
+ double numDistances = 0;
- for (int p = 0; p < points.size(); p++) {
- for (int c = 0; c < clustering.size(); c++) {
- if(clustering.get(c).getInclusionProbability(points.get(p)) > 0.8){
- pointsInClusters[c].add(p);
- numDistances++;
- }
- }
- }
-
- //calc within cluster distances + min and max values
- for (int c = 0; c < numClusters; c++) {
- int numDistancesInC = 0;
- ArrayList<Integer> pointsInC = pointsInClusters[c];
- for (int p = 0; p < pointsInC.size(); p++) {
- DataPoint point = points.get(pointsInC.get(p));
- for (int p1 = p+1; p1 < pointsInC.size(); p1++) {
- numDistancesWithin++;
- numDistancesInC++;
- DataPoint point1 = points.get(pointsInC.get(p1));
- double dist = point.getDistance(point1);
- withinClustersDistance+=dist;
- if(minWithinClusters[c] > dist) minWithinClusters[c] = dist;
- if(maxWithinClusters[c] < dist) maxWithinClusters[c] = dist;
- }
- }
- }
-
- double minWithin = Double.MAX_VALUE;
- double maxWithin = Double.MIN_VALUE;
- for (int c = 0; c < numClusters; c++) {
- if(minWithinClusters[c] < minWithin)
- minWithin = minWithinClusters[c];
- if(maxWithinClusters[c] > maxWithin)
- maxWithin = maxWithinClusters[c];
- }
-
- double cindex = 0;
- if(numDistancesWithin != 0){
- double meanWithinClustersDistance = withinClustersDistance/numDistancesWithin;
- cindex = (meanWithinClustersDistance - minWithin)/(maxWithin-minWithin);
- }
-
-
- if(debug){
- System.out.println("Min:"+Arrays.toString(minWithinClusters));
- System.out.println("Max:"+Arrays.toString(maxWithinClusters));
- System.out.println("totalWithin:"+numDistancesWithin);
- }
- return cindex;
+ // double[] withinClusters = new double[numClusters];
+ double[] minWithinClusters = new double[numClusters];
+ double[] maxWithinClusters = new double[numClusters];
+ ArrayList<Integer>[] pointsInClusters = new ArrayList[numClusters];
+ for (int c = 0; c < numClusters; c++) {
+ pointsInClusters[c] = new ArrayList<>();
+ minWithinClusters[c] = Double.MAX_VALUE;
+ maxWithinClusters[c] = Double.MIN_VALUE;
}
+ for (int p = 0; p < points.size(); p++) {
+ for (int c = 0; c < clustering.size(); c++) {
+ if (clustering.get(c).getInclusionProbability(points.get(p)) > 0.8) {
+ pointsInClusters[c].add(p);
+ numDistances++;
+ }
+ }
+ }
+
+ // calc within cluster distances + min and max values
+ for (int c = 0; c < numClusters; c++) {
+ int numDistancesInC = 0;
+ ArrayList<Integer> pointsInC = pointsInClusters[c];
+ for (int p = 0; p < pointsInC.size(); p++) {
+ DataPoint point = points.get(pointsInC.get(p));
+ for (int p1 = p + 1; p1 < pointsInC.size(); p1++) {
+ numDistancesWithin++;
+ numDistancesInC++;
+ DataPoint point1 = points.get(pointsInC.get(p1));
+ double dist = point.getDistance(point1);
+ withinClustersDistance += dist;
+ if (minWithinClusters[c] > dist)
+ minWithinClusters[c] = dist;
+ if (maxWithinClusters[c] < dist)
+ maxWithinClusters[c] = dist;
+ }
+ }
+ }
+
+ double minWithin = Double.MAX_VALUE;
+ double maxWithin = Double.MIN_VALUE;
+ for (int c = 0; c < numClusters; c++) {
+ if (minWithinClusters[c] < minWithin)
+ minWithin = minWithinClusters[c];
+ if (maxWithinClusters[c] > maxWithin)
+ maxWithin = maxWithinClusters[c];
+ }
+
+ double cindex = 0;
+ if (numDistancesWithin != 0) {
+ double meanWithinClustersDistance = withinClustersDistance / numDistancesWithin;
+ cindex = (meanWithinClustersDistance - minWithin) / (maxWithin - minWithin);
+ }
+
+ if (debug) {
+ System.out.println("Min:" + Arrays.toString(minWithinClusters));
+ System.out.println("Max:" + Arrays.toString(maxWithinClusters));
+ System.out.println("totalWithin:" + numDistancesWithin);
+ }
+ return cindex;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldContentEvent.java
index 5e86cb0..82052fd 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldContentEvent.java
@@ -27,43 +27,43 @@
*/
public class HelloWorldContentEvent implements ContentEvent {
- private static final long serialVersionUID = -2406968925730298156L;
- private final boolean isLastEvent;
- private final int helloWorldData;
+ private static final long serialVersionUID = -2406968925730298156L;
+ private final boolean isLastEvent;
+ private final int helloWorldData;
- public HelloWorldContentEvent(int helloWorldData, boolean isLastEvent) {
- this.isLastEvent = isLastEvent;
- this.helloWorldData = helloWorldData;
- }
-
- /*
- * No-argument constructor for Kryo
- */
- public HelloWorldContentEvent() {
- this(0,false);
- }
+ public HelloWorldContentEvent(int helloWorldData, boolean isLastEvent) {
+ this.isLastEvent = isLastEvent;
+ this.helloWorldData = helloWorldData;
+ }
- @Override
- public String getKey() {
- return null;
- }
+ /*
+ * No-argument constructor for Kryo
+ */
+ public HelloWorldContentEvent() {
+ this(0, false);
+ }
- @Override
- public void setKey(String str) {
- // do nothing, it's key-less content event
- }
+ @Override
+ public String getKey() {
+ return null;
+ }
- @Override
- public boolean isLastEvent() {
- return isLastEvent;
- }
+ @Override
+ public void setKey(String str) {
+ // do nothing, it's key-less content event
+ }
- public int getHelloWorldData() {
- return helloWorldData;
- }
+ @Override
+ public boolean isLastEvent() {
+ return isLastEvent;
+ }
- @Override
- public String toString() {
- return "HelloWorldContentEvent [helloWorldData=" + helloWorldData + "]";
- }
+ public int getHelloWorldData() {
+ return helloWorldData;
+ }
+
+ @Override
+ public String toString() {
+ return "HelloWorldContentEvent [helloWorldData=" + helloWorldData + "]";
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldDestinationProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldDestinationProcessor.java
index e22c0fe..3d8aac7 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldDestinationProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldDestinationProcessor.java
@@ -24,26 +24,27 @@
import com.yahoo.labs.samoa.core.Processor;
/**
- * Example {@link Processor} that simply prints the received events to standard output.
+ * Example {@link Processor} that simply prints the received events to standard
+ * output.
*/
public class HelloWorldDestinationProcessor implements Processor {
- private static final long serialVersionUID = -6042613438148776446L;
- private int processorId;
+ private static final long serialVersionUID = -6042613438148776446L;
+ private int processorId;
- @Override
- public boolean process(ContentEvent event) {
- System.out.println(processorId + ": " + event);
- return true;
- }
+ @Override
+ public boolean process(ContentEvent event) {
+ System.out.println(processorId + ": " + event);
+ return true;
+ }
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- }
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ }
- @Override
- public Processor newProcessor(Processor p) {
- return new HelloWorldDestinationProcessor();
- }
+ @Override
+ public Processor newProcessor(Processor p) {
+ return new HelloWorldDestinationProcessor();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldSourceProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldSourceProcessor.java
index a37201f..1f4517d 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldSourceProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldSourceProcessor.java
@@ -31,45 +31,45 @@
*/
public class HelloWorldSourceProcessor implements EntranceProcessor {
- private static final long serialVersionUID = 6212296305865604747L;
- private Random rnd;
- private final long maxInst;
- private long count;
+ private static final long serialVersionUID = 6212296305865604747L;
+ private Random rnd;
+ private final long maxInst;
+ private long count;
- public HelloWorldSourceProcessor(long maxInst) {
- this.maxInst = maxInst;
- }
+ public HelloWorldSourceProcessor(long maxInst) {
+ this.maxInst = maxInst;
+ }
- @Override
- public boolean process(ContentEvent event) {
- // do nothing, API will be refined further
- return false;
- }
+ @Override
+ public boolean process(ContentEvent event) {
+ // do nothing, API will be refined further
+ return false;
+ }
- @Override
- public void onCreate(int id) {
- rnd = new Random(id);
- }
+ @Override
+ public void onCreate(int id) {
+ rnd = new Random(id);
+ }
- @Override
- public Processor newProcessor(Processor p) {
- HelloWorldSourceProcessor hwsp = (HelloWorldSourceProcessor) p;
- return new HelloWorldSourceProcessor(hwsp.maxInst);
- }
+ @Override
+ public Processor newProcessor(Processor p) {
+ HelloWorldSourceProcessor hwsp = (HelloWorldSourceProcessor) p;
+ return new HelloWorldSourceProcessor(hwsp.maxInst);
+ }
- @Override
- public boolean isFinished() {
- return count >= maxInst;
- }
-
- @Override
- public boolean hasNext() {
- return count < maxInst;
- }
+ @Override
+ public boolean isFinished() {
+ return count >= maxInst;
+ }
- @Override
- public ContentEvent nextEvent() {
- count++;
- return new HelloWorldContentEvent(rnd.nextInt(), false);
- }
+ @Override
+ public boolean hasNext() {
+ return count < maxInst;
+ }
+
+ @Override
+ public ContentEvent nextEvent() {
+ count++;
+ return new HelloWorldContentEvent(rnd.nextInt(), false);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldTask.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldTask.java
index e6658f1..2c8e36c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldTask.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/examples/HelloWorldTask.java
@@ -36,63 +36,67 @@
import com.yahoo.labs.samoa.topology.TopologyBuilder;
/**
- * Example {@link Task} in SAMOA. This task simply sends events from a source {@link HelloWorldSourceProcessor} to a destination
- * {@link HelloWorldDestinationProcessor}. The events are random integers generated by the source and encapsulated in a {@link HelloWorldContentEvent}. The
- * destination prints the content of the event to standard output, prepended by the processor id.
+ * Example {@link Task} in SAMOA. This task simply sends events from a source
+ * {@link HelloWorldSourceProcessor} to a destination
+ * {@link HelloWorldDestinationProcessor}. The events are random integers
+ * generated by the source and encapsulated in a {@link HelloWorldContentEvent}.
+ * The destination prints the content of the event to standard output, prepended
+ * by the processor id.
*
- * The task has 2 main options: the number of events the source will generate (-i) and the parallelism level of the destination (-p).
+ * The task has 2 main options: the number of events the source will generate
+ * (-i) and the parallelism level of the destination (-p).
*/
public class HelloWorldTask implements Task, Configurable {
- private static final long serialVersionUID = -5134935141154021352L;
- private static Logger logger = LoggerFactory.getLogger(HelloWorldTask.class);
+ private static final long serialVersionUID = -5134935141154021352L;
+ private static Logger logger = LoggerFactory.getLogger(HelloWorldTask.class);
- /** The topology builder for the task. */
- private TopologyBuilder builder;
- /** The topology that will be created for the task */
- private Topology helloWorldTopology;
+ /** The topology builder for the task. */
+ private TopologyBuilder builder;
+ /** The topology that will be created for the task */
+ private Topology helloWorldTopology;
- public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i',
- "Maximum number of instances to generate (-1 = no limit).", 1000000, -1, Integer.MAX_VALUE);
+ public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i',
+ "Maximum number of instances to generate (-1 = no limit).", 1000000, -1, Integer.MAX_VALUE);
- public IntOption helloWorldParallelismOption = new IntOption("parallelismOption", 'p',
- "Number of destination Processors", 1, 1, Integer.MAX_VALUE);
+ public IntOption helloWorldParallelismOption = new IntOption("parallelismOption", 'p',
+ "Number of destination Processors", 1, 1, Integer.MAX_VALUE);
- public StringOption evaluationNameOption = new StringOption("evaluationName", 'n',
- "Identifier of the evaluation", "HelloWorldTask" + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
+ public StringOption evaluationNameOption = new StringOption("evaluationName", 'n',
+ "Identifier of the evaluation", "HelloWorldTask" + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
- @Override
- public void init() {
- // create source EntranceProcessor
- /* The event source for the topology. Implements EntranceProcessor */
- HelloWorldSourceProcessor sourceProcessor = new HelloWorldSourceProcessor(instanceLimitOption.getValue());
- builder.addEntranceProcessor(sourceProcessor);
+ @Override
+ public void init() {
+ // create source EntranceProcessor
+ /* The event source for the topology. Implements EntranceProcessor */
+ HelloWorldSourceProcessor sourceProcessor = new HelloWorldSourceProcessor(instanceLimitOption.getValue());
+ builder.addEntranceProcessor(sourceProcessor);
- // create Stream
- Stream stream = builder.createStream(sourceProcessor);
+ // create Stream
+ Stream stream = builder.createStream(sourceProcessor);
- // create destination Processor
- /* The event sink for the topology. Implements Processor */
- HelloWorldDestinationProcessor destProcessor = new HelloWorldDestinationProcessor();
- builder.addProcessor(destProcessor, helloWorldParallelismOption.getValue());
- builder.connectInputShuffleStream(stream, destProcessor);
+ // create destination Processor
+ /* The event sink for the topology. Implements Processor */
+ HelloWorldDestinationProcessor destProcessor = new HelloWorldDestinationProcessor();
+ builder.addProcessor(destProcessor, helloWorldParallelismOption.getValue());
+ builder.connectInputShuffleStream(stream, destProcessor);
- // build the topology
- helloWorldTopology = builder.build();
- logger.debug("Successfully built the topology");
- }
+ // build the topology
+ helloWorldTopology = builder.build();
+ logger.debug("Successfully built the topology");
+ }
- @Override
- public Topology getTopology() {
- return helloWorldTopology;
- }
+ @Override
+ public Topology getTopology() {
+ return helloWorldTopology;
+ }
- @Override
- public void setFactory(ComponentFactory factory) {
- // will be removed when dynamic binding is implemented
- builder = new TopologyBuilder(factory);
- logger.debug("Successfully instantiating TopologyBuilder");
- builder.initTopology(evaluationNameOption.getValue());
- logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
- }
+ @Override
+ public void setFactory(ComponentFactory factory) {
+ // will be removed when dynamic binding is implemented
+ builder = new TopologyBuilder(factory);
+ logger.debug("Successfully instantiating TopologyBuilder");
+ builder.initTopology(evaluationNameOption.getValue());
+ logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/AdaptiveLearner.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/AdaptiveLearner.java
index 0986253..e465b7d 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/AdaptiveLearner.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/AdaptiveLearner.java
@@ -24,30 +24,30 @@
* License
*/
-
import com.yahoo.labs.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
import com.yahoo.labs.samoa.topology.Stream;
/**
- * The Interface Adaptive Learner.
- * Initializing Classifier should initalize PI to connect the Classifier with the input stream
- * and initialize result stream so that other PI can connect to the classification result of this classifier
+ * The Interface Adaptive Learner. Initializing Classifier should initalize PI
+ * to connect the Classifier with the input stream and initialize result stream
+ * so that other PI can connect to the classification result of this classifier
*/
public interface AdaptiveLearner {
- /**
- * Gets the change detector item.
- *
- * @return the change detector item
- */
- public ChangeDetector getChangeDetector();
+ /**
+ * Gets the change detector item.
+ *
+ * @return the change detector item
+ */
+ public ChangeDetector getChangeDetector();
- /**
- * Sets the change detector item.
- *
- * @param cd the change detector item
- */
- public void setChangeDetector(ChangeDetector cd);
-
+ /**
+ * Sets the change detector item.
+ *
+ * @param cd
+ * the change detector item
+ */
+ public void setChangeDetector(ChangeDetector cd);
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstanceContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstanceContentEvent.java
index 91b1b7b..fd25736 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstanceContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstanceContentEvent.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.learners;
/*
@@ -29,8 +28,8 @@
import com.yahoo.labs.samoa.core.SerializableInstance;
import net.jcip.annotations.Immutable;
import com.yahoo.labs.samoa.instances.Instance;
-//import weka.core.Instance;
+//import weka.core.Instance;
/**
* The Class InstanceEvent.
@@ -38,170 +37,178 @@
@Immutable
final public class InstanceContentEvent implements ContentEvent {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -8620668863064613845L;
- private long instanceIndex;
- private int classifierIndex;
- private int evaluationIndex;
- private SerializableInstance instance;
- private boolean isTraining;
- private boolean isTesting;
- private boolean isLast = false;
-
- public InstanceContentEvent() {
-
- }
+ private static final long serialVersionUID = -8620668863064613845L;
+ private long instanceIndex;
+ private int classifierIndex;
+ private int evaluationIndex;
+ private SerializableInstance instance;
+ private boolean isTraining;
+ private boolean isTesting;
+ private boolean isLast = false;
- /**
- * Instantiates a new instance event.
- *
- * @param index the index
- * @param instance the instance
- * @param isTraining the is training
- */
- public InstanceContentEvent(long index, Instance instance,
- boolean isTraining, boolean isTesting) {
- if (instance != null) {
- this.instance = new SerializableInstance(instance);
- }
- this.instanceIndex = index;
- this.isTraining = isTraining;
- this.isTesting = isTesting;
- }
+ public InstanceContentEvent() {
- /**
- * Gets the single instance of InstanceEvent.
- *
- * @return the instance.
- */
- public Instance getInstance() {
- return instance;
- }
+ }
- /**
- * Gets the instance index.
- *
- * @return the index of the data vector.
- */
- public long getInstanceIndex() {
- return instanceIndex;
- }
+ /**
+ * Instantiates a new instance event.
+ *
+ * @param index
+ * the index
+ * @param instance
+ * the instance
+ * @param isTraining
+ * the is training
+ */
+ public InstanceContentEvent(long index, Instance instance,
+ boolean isTraining, boolean isTesting) {
+ if (instance != null) {
+ this.instance = new SerializableInstance(instance);
+ }
+ this.instanceIndex = index;
+ this.isTraining = isTraining;
+ this.isTesting = isTesting;
+ }
- /**
- * Gets the class id.
- *
- * @return the true class of the vector.
- */
- public int getClassId() {
- // return classId;
- return (int) instance.classValue();
- }
+ /**
+ * Gets the single instance of InstanceEvent.
+ *
+ * @return the instance.
+ */
+ public Instance getInstance() {
+ return instance;
+ }
- /**
- * Checks if is training.
- *
- * @return true if this is training data.
- */
- public boolean isTraining() {
- return isTraining;
- }
-
- /**
- * Set training flag.
- *
- * @param training flag.
- */
- public void setTraining(boolean training) {
- this.isTraining = training;
- }
-
- /**
- * Checks if is testing.
- *
- * @return true if this is testing data.
- */
- public boolean isTesting(){
- return isTesting;
- }
-
- /**
- * Set testing flag.
- *
- * @param testing flag.
- */
- public void setTesting(boolean testing) {
- this.isTesting = testing;
- }
+ /**
+ * Gets the instance index.
+ *
+ * @return the index of the data vector.
+ */
+ public long getInstanceIndex() {
+ return instanceIndex;
+ }
- /**
- * Gets the classifier index.
- *
- * @return the classifier index
- */
- public int getClassifierIndex() {
- return classifierIndex;
- }
+ /**
+ * Gets the class id.
+ *
+ * @return the true class of the vector.
+ */
+ public int getClassId() {
+ // return classId;
+ return (int) instance.classValue();
+ }
- /**
- * Sets the classifier index.
- *
- * @param classifierIndex the new classifier index
- */
- public void setClassifierIndex(int classifierIndex) {
- this.classifierIndex = classifierIndex;
- }
+ /**
+ * Checks if is training.
+ *
+ * @return true if this is training data.
+ */
+ public boolean isTraining() {
+ return isTraining;
+ }
- /**
- * Gets the evaluation index.
- *
- * @return the evaluation index
- */
- public int getEvaluationIndex() {
- return evaluationIndex;
- }
+ /**
+ * Set training flag.
+ *
+ * @param training
+ * flag.
+ */
+ public void setTraining(boolean training) {
+ this.isTraining = training;
+ }
- /**
- * Sets the evaluation index.
- *
- * @param evaluationIndex the new evaluation index
- */
- public void setEvaluationIndex(int evaluationIndex) {
- this.evaluationIndex = evaluationIndex;
- }
+ /**
+ * Checks if is testing.
+ *
+ * @return true if this is testing data.
+ */
+ public boolean isTesting() {
+ return isTesting;
+ }
- /* (non-Javadoc)
- * @see samoa.core.ContentEvent#getKey(int)
- */
- public String getKey(int key) {
- if (key == 0)
- return Long.toString(this.getEvaluationIndex());
- else return Long.toString(10000
- * this.getEvaluationIndex()
- + this.getClassifierIndex());
- }
+ /**
+ * Set testing flag.
+ *
+ * @param testing
+ * flag.
+ */
+ public void setTesting(boolean testing) {
+ this.isTesting = testing;
+ }
- @Override
- public String getKey() {
- //System.out.println("InstanceContentEvent "+Long.toString(this.instanceIndex));
- return Long.toString(this.getClassifierIndex());
- }
+ /**
+ * Gets the classifier index.
+ *
+ * @return the classifier index
+ */
+ public int getClassifierIndex() {
+ return classifierIndex;
+ }
- @Override
- public void setKey(String str) {
- this.instanceIndex = Long.parseLong(str);
- }
+ /**
+ * Sets the classifier index.
+ *
+ * @param classifierIndex
+ * the new classifier index
+ */
+ public void setClassifierIndex(int classifierIndex) {
+ this.classifierIndex = classifierIndex;
+ }
- @Override
- public boolean isLastEvent() {
- return isLast;
- }
+ /**
+ * Gets the evaluation index.
+ *
+ * @return the evaluation index
+ */
+ public int getEvaluationIndex() {
+ return evaluationIndex;
+ }
- public void setLast(boolean isLast) {
- this.isLast = isLast;
- }
+ /**
+ * Sets the evaluation index.
+ *
+ * @param evaluationIndex
+ * the new evaluation index
+ */
+ public void setEvaluationIndex(int evaluationIndex) {
+ this.evaluationIndex = evaluationIndex;
+ }
-
-
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.ContentEvent#getKey(int)
+ */
+ public String getKey(int key) {
+ if (key == 0)
+ return Long.toString(this.getEvaluationIndex());
+ else
+ return Long.toString(10000
+ * this.getEvaluationIndex()
+ + this.getClassifierIndex());
+ }
+
+ @Override
+ public String getKey() {
+ // System.out.println("InstanceContentEvent "+Long.toString(this.instanceIndex));
+ return Long.toString(this.getClassifierIndex());
+ }
+
+ @Override
+ public void setKey(String str) {
+ this.instanceIndex = Long.parseLong(str);
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return isLast;
+ }
+
+ public void setLast(boolean isLast) {
+ this.isLast = isLast;
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstancesContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstancesContentEvent.java
index ff005b6..ce5937a 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstancesContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/InstancesContentEvent.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.learners;
/*
@@ -31,8 +30,8 @@
import com.yahoo.labs.samoa.instances.Instance;
import java.util.LinkedList;
import java.util.List;
-//import weka.core.Instance;
+//import weka.core.Instance;
/**
* The Class InstanceEvent.
@@ -40,154 +39,161 @@
@Immutable
final public class InstancesContentEvent implements ContentEvent {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -8620668863064613845L;
- private long instanceIndex;
- private int classifierIndex;
- private int evaluationIndex;
- //private SerializableInstance instance;
- private boolean isTraining;
- private boolean isTesting;
- private boolean isLast = false;
-
- public InstancesContentEvent() {
-
- }
+ private static final long serialVersionUID = -8620668863064613845L;
+ private long instanceIndex;
+ private int classifierIndex;
+ private int evaluationIndex;
+ // private SerializableInstance instance;
+ private boolean isTraining;
+ private boolean isTesting;
+ private boolean isLast = false;
- /**
- * Instantiates a new instance event.
- *
- * @param index the index
- * @param instance the instance
- * @param isTraining the is training
- */
- public InstancesContentEvent(long index,// Instance instance,
- boolean isTraining, boolean isTesting) {
- /*if (instance != null) {
- this.instance = new SerializableInstance(instance);
- }*/
- this.instanceIndex = index;
- this.isTraining = isTraining;
- this.isTesting = isTesting;
- }
-
- public InstancesContentEvent(InstanceContentEvent event){
- this.instanceIndex = event.getInstanceIndex();
- this.isTraining = event.isTraining();
- this.isTesting = event.isTesting();
- }
+ public InstancesContentEvent() {
- protected List<Instance> instanceList = new LinkedList<Instance>();
-
- public void add(Instance instance){
- instanceList.add(new SerializableInstance(instance));
- }
-
- /**
- * Gets the single instance of InstanceEvent.
- *
- * @return the instance.
- */
- public Instance[] getInstances() {
- return instanceList.toArray(new Instance[instanceList.size()]);
- }
+ }
- /**
- * Gets the instance index.
- *
- * @return the index of the data vector.
- */
- public long getInstanceIndex() {
- return instanceIndex;
- }
+ /**
+ * Instantiates a new instance event.
+ *
+ * @param index
+ * the index
+ * @param instance
+ * the instance
+ * @param isTraining
+ * the is training
+ */
+ public InstancesContentEvent(long index,// Instance instance,
+ boolean isTraining, boolean isTesting) {
+ /*
+ * if (instance != null) { this.instance = new
+ * SerializableInstance(instance); }
+ */
+ this.instanceIndex = index;
+ this.isTraining = isTraining;
+ this.isTesting = isTesting;
+ }
- /**
- * Checks if is training.
- *
- * @return true if this is training data.
- */
- public boolean isTraining() {
- return isTraining;
- }
-
- /**
- * Checks if is testing.
- *
- * @return true if this is testing data.
- */
- public boolean isTesting(){
- return isTesting;
- }
+ public InstancesContentEvent(InstanceContentEvent event) {
+ this.instanceIndex = event.getInstanceIndex();
+ this.isTraining = event.isTraining();
+ this.isTesting = event.isTesting();
+ }
- /**
- * Gets the classifier index.
- *
- * @return the classifier index
- */
- public int getClassifierIndex() {
- return classifierIndex;
- }
+ protected List<Instance> instanceList = new LinkedList<Instance>();
- /**
- * Sets the classifier index.
- *
- * @param classifierIndex the new classifier index
- */
- public void setClassifierIndex(int classifierIndex) {
- this.classifierIndex = classifierIndex;
- }
+ public void add(Instance instance) {
+ instanceList.add(new SerializableInstance(instance));
+ }
- /**
- * Gets the evaluation index.
- *
- * @return the evaluation index
- */
- public int getEvaluationIndex() {
- return evaluationIndex;
- }
+ /**
+ * Gets the single instance of InstanceEvent.
+ *
+ * @return the instance.
+ */
+ public Instance[] getInstances() {
+ return instanceList.toArray(new Instance[instanceList.size()]);
+ }
- /**
- * Sets the evaluation index.
- *
- * @param evaluationIndex the new evaluation index
- */
- public void setEvaluationIndex(int evaluationIndex) {
- this.evaluationIndex = evaluationIndex;
- }
+ /**
+ * Gets the instance index.
+ *
+ * @return the index of the data vector.
+ */
+ public long getInstanceIndex() {
+ return instanceIndex;
+ }
- /* (non-Javadoc)
- * @see samoa.core.ContentEvent#getKey(int)
- */
- public String getKey(int key) {
- if (key == 0)
- return Long.toString(this.getEvaluationIndex());
- else return Long.toString(10000
- * this.getEvaluationIndex()
- + this.getClassifierIndex());
- }
+ /**
+ * Checks if is training.
+ *
+ * @return true if this is training data.
+ */
+ public boolean isTraining() {
+ return isTraining;
+ }
- @Override
- public String getKey() {
- //System.out.println("InstanceContentEvent "+Long.toString(this.instanceIndex));
- return Long.toString(this.getClassifierIndex());
- }
+ /**
+ * Checks if is testing.
+ *
+ * @return true if this is testing data.
+ */
+ public boolean isTesting() {
+ return isTesting;
+ }
- @Override
- public void setKey(String str) {
- this.instanceIndex = Long.parseLong(str);
- }
+ /**
+ * Gets the classifier index.
+ *
+ * @return the classifier index
+ */
+ public int getClassifierIndex() {
+ return classifierIndex;
+ }
- @Override
- public boolean isLastEvent() {
- return isLast;
- }
+ /**
+ * Sets the classifier index.
+ *
+ * @param classifierIndex
+ * the new classifier index
+ */
+ public void setClassifierIndex(int classifierIndex) {
+ this.classifierIndex = classifierIndex;
+ }
- public void setLast(boolean isLast) {
- this.isLast = isLast;
- }
+ /**
+ * Gets the evaluation index.
+ *
+ * @return the evaluation index
+ */
+ public int getEvaluationIndex() {
+ return evaluationIndex;
+ }
-
-
+ /**
+ * Sets the evaluation index.
+ *
+ * @param evaluationIndex
+ * the new evaluation index
+ */
+ public void setEvaluationIndex(int evaluationIndex) {
+ this.evaluationIndex = evaluationIndex;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.ContentEvent#getKey(int)
+ */
+ public String getKey(int key) {
+ if (key == 0)
+ return Long.toString(this.getEvaluationIndex());
+ else
+ return Long.toString(10000
+ * this.getEvaluationIndex()
+ + this.getClassifierIndex());
+ }
+
+ @Override
+ public String getKey() {
+ // System.out.println("InstanceContentEvent "+Long.toString(this.instanceIndex));
+ return Long.toString(this.getClassifierIndex());
+ }
+
+ @Override
+ public void setKey(String str) {
+ this.instanceIndex = Long.parseLong(str);
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return isLast;
+ }
+
+ public void setLast(boolean isLast) {
+ this.isLast = isLast;
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/Learner.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/Learner.java
index 993ca47..636b023 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/Learner.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/Learner.java
@@ -29,34 +29,36 @@
import java.util.Set;
/**
- * The Interface Classifier.
- * Initializing Classifier should initalize PI to connect the Classifier with the input stream
- * and initialize result stream so that other PI can connect to the classification result of this classifier
+ * The Interface Classifier. Initializing Classifier should initalize PI to
+ * connect the Classifier with the input stream and initialize result stream so
+ * that other PI can connect to the classification result of this classifier
*/
-public interface Learner extends Serializable{
+public interface Learner extends Serializable {
- /**
- * Inits the Learner object.
- *
- * @param topologyBuilder the topology builder
- * @param dataset the dataset
- * @param parallelism the parallelism
- */
- public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism);
-
- /**
- * Gets the input processing item.
- *
- * @return the input processing item
- */
- public Processor getInputProcessor();
+ /**
+ * Inits the Learner object.
+ *
+ * @param topologyBuilder
+ * the topology builder
+ * @param dataset
+ * the dataset
+ * @param parallelism
+ * the parallelism
+ */
+ public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism);
-
- /**
- * Gets the result streams
- *
- * @return the set of result streams
- */
- public Set<Stream> getResultStreams();
+ /**
+ * Gets the input processing item.
+ *
+ * @return the input processing item
+ */
+ public Processor getInputProcessor();
+
+ /**
+ * Gets the result streams
+ *
+ * @return the set of result streams
+ */
+ public Set<Stream> getResultStreams();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/ResultContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/ResultContentEvent.java
index 0879872..cb1e317 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/ResultContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/ResultContentEvent.java
@@ -28,185 +28,186 @@
* License
*/
-
/**
* The Class ResultEvent.
*/
final public class ResultContentEvent implements ContentEvent {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -2650420235386873306L;
- private long instanceIndex;
- private int classifierIndex;
- private int evaluationIndex;
- private SerializableInstance instance;
-
- private int classId;
- private double[] classVotes;
-
- private final boolean isLast;
-
- public ResultContentEvent(){
- this.isLast = false;
- }
+ private static final long serialVersionUID = -2650420235386873306L;
+ private long instanceIndex;
+ private int classifierIndex;
+ private int evaluationIndex;
+ private SerializableInstance instance;
+ private int classId;
+ private double[] classVotes;
- public ResultContentEvent(boolean isLast) {
- this.isLast = isLast;
- }
+ private final boolean isLast;
- /**
- * Instantiates a new result event.
- *
- * @param instanceIndex
- * the instance index
- * @param instance
- * the instance
- * @param classId
- * the class id
- * @param classVotes
- * the class votes
- */
- public ResultContentEvent(long instanceIndex, Instance instance, int classId,
- double[] classVotes, boolean isLast) {
- if(instance != null){
- this.instance = new SerializableInstance(instance);
- }
- this.instanceIndex = instanceIndex;
- this.classId = classId;
- this.classVotes = classVotes;
- this.isLast = isLast;
- }
+ public ResultContentEvent() {
+ this.isLast = false;
+ }
- /**
- * Gets the single instance of ResultEvent.
- *
- * @return single instance of ResultEvent
- */
- public SerializableInstance getInstance() {
- return instance;
- }
+ public ResultContentEvent(boolean isLast) {
+ this.isLast = isLast;
+ }
- /**
- * Sets the instance.
- *
- * @param instance
- * the new instance
- */
- public void setInstance(SerializableInstance instance) {
- this.instance = instance;
- }
+ /**
+ * Instantiates a new result event.
+ *
+ * @param instanceIndex
+ * the instance index
+ * @param instance
+ * the instance
+ * @param classId
+ * the class id
+ * @param classVotes
+ * the class votes
+ */
+ public ResultContentEvent(long instanceIndex, Instance instance, int classId,
+ double[] classVotes, boolean isLast) {
+ if (instance != null) {
+ this.instance = new SerializableInstance(instance);
+ }
+ this.instanceIndex = instanceIndex;
+ this.classId = classId;
+ this.classVotes = classVotes;
+ this.isLast = isLast;
+ }
- /**
- * Gets the num classes.
- *
- * @return the num classes
- */
- public int getNumClasses() { // To remove
- return instance.numClasses();
- }
+ /**
+ * Gets the single instance of ResultEvent.
+ *
+ * @return single instance of ResultEvent
+ */
+ public SerializableInstance getInstance() {
+ return instance;
+ }
- /**
- * Gets the instance index.
- *
- * @return the index of the data vector.
- */
- public long getInstanceIndex() {
- return instanceIndex;
- }
+ /**
+ * Sets the instance.
+ *
+ * @param instance
+ * the new instance
+ */
+ public void setInstance(SerializableInstance instance) {
+ this.instance = instance;
+ }
- /**
- * Gets the class id.
- *
- * @return the true class of the vector.
- */
- public int getClassId() { // To remove
- return classId;// (int) instance.classValue();//classId;
- }
+ /**
+ * Gets the num classes.
+ *
+ * @return the num classes
+ */
+ public int getNumClasses() { // To remove
+ return instance.numClasses();
+ }
- /**
- * Gets the class votes.
- *
- * @return the class votes
- */
- public double[] getClassVotes() {
- return classVotes;
- }
+ /**
+ * Gets the instance index.
+ *
+ * @return the index of the data vector.
+ */
+ public long getInstanceIndex() {
+ return instanceIndex;
+ }
- /**
- * Sets the class votes.
- *
- * @param classVotes
- * the new class votes
- */
- public void setClassVotes(double[] classVotes) {
- this.classVotes = classVotes;
- }
+ /**
+ * Gets the class id.
+ *
+ * @return the true class of the vector.
+ */
+ public int getClassId() { // To remove
+ return classId;// (int) instance.classValue();//classId;
+ }
- /**
- * Gets the classifier index.
- *
- * @return the classifier index
- */
- public int getClassifierIndex() {
- return classifierIndex;
- }
+ /**
+ * Gets the class votes.
+ *
+ * @return the class votes
+ */
+ public double[] getClassVotes() {
+ return classVotes;
+ }
- /**
- * Sets the classifier index.
- *
- * @param classifierIndex
- * the new classifier index
- */
- public void setClassifierIndex(int classifierIndex) {
- this.classifierIndex = classifierIndex;
- }
+ /**
+ * Sets the class votes.
+ *
+ * @param classVotes
+ * the new class votes
+ */
+ public void setClassVotes(double[] classVotes) {
+ this.classVotes = classVotes;
+ }
- /**
- * Gets the evaluation index.
- *
- * @return the evaluation index
- */
- public int getEvaluationIndex() {
- return evaluationIndex;
- }
+ /**
+ * Gets the classifier index.
+ *
+ * @return the classifier index
+ */
+ public int getClassifierIndex() {
+ return classifierIndex;
+ }
- /**
- * Sets the evaluation index.
- *
- * @param evaluationIndex
- * the new evaluation index
- */
- public void setEvaluationIndex(int evaluationIndex) {
- this.evaluationIndex = evaluationIndex;
- }
-
- /* (non-Javadoc)
- * @see samoa.core.ContentEvent#getKey(int)
- */
- //@Override
- public String getKey(int key) {
- if (key == 0)
- return Long.toString(this.getEvaluationIndex());
- else return Long.toString(this.getEvaluationIndex()
- + 1000 * this.getInstanceIndex());
- }
+ /**
+ * Sets the classifier index.
+ *
+ * @param classifierIndex
+ * the new classifier index
+ */
+ public void setClassifierIndex(int classifierIndex) {
+ this.classifierIndex = classifierIndex;
+ }
- @Override
- public String getKey() {
- return Long.toString(this.getEvaluationIndex()%100);
- }
+ /**
+ * Gets the evaluation index.
+ *
+ * @return the evaluation index
+ */
+ public int getEvaluationIndex() {
+ return evaluationIndex;
+ }
- @Override
- public void setKey(String str) {
- this.evaluationIndex = Integer.parseInt(str);
- }
+ /**
+ * Sets the evaluation index.
+ *
+ * @param evaluationIndex
+ * the new evaluation index
+ */
+ public void setEvaluationIndex(int evaluationIndex) {
+ this.evaluationIndex = evaluationIndex;
+ }
- @Override
- public boolean isLastEvent() {
- return isLast;
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.ContentEvent#getKey(int)
+ */
+ // @Override
+ public String getKey(int key) {
+ if (key == 0)
+ return Long.toString(this.getEvaluationIndex());
+ else
+ return Long.toString(this.getEvaluationIndex()
+ + 1000 * this.getInstanceIndex());
+ }
+
+ @Override
+ public String getKey() {
+ return Long.toString(this.getEvaluationIndex() % 100);
+ }
+
+ @Override
+ public void setKey(String str) {
+ this.evaluationIndex = Integer.parseInt(str);
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return isLast;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearner.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearner.java
index b5c30db..80ddbd2 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearner.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearner.java
@@ -32,47 +32,47 @@
* @author abifet
*/
public interface LocalLearner extends Serializable {
-
- /**
- * Creates a new learner object.
- *
- * @return the learner
- */
- LocalLearner create();
-
- /**
- * Predicts the class memberships for a given instance. If an instance is
- * unclassified, the returned array elements must be all zero.
- *
- * @param inst
- * the instance to be classified
- * @return an array containing the estimated membership probabilities of the
- * test instance in each class
- */
- double[] getVotesForInstance(Instance inst);
- /**
- * Resets this classifier. It must be similar to starting a new classifier
- * from scratch.
- *
- */
- void resetLearning();
+ /**
+ * Creates a new learner object.
+ *
+ * @return the learner
+ */
+ LocalLearner create();
- /**
- * Trains this classifier incrementally using the given instance.
- *
- * @param inst
- * the instance to be used for training
- */
- void trainOnInstance(Instance inst);
-
- /**
- * Sets where to obtain the information of attributes of Instances
- *
- * @param dataset
- * the dataset that contains the information
- */
- @Deprecated
- public void setDataset(Instances dataset);
+ /**
+ * Predicts the class memberships for a given instance. If an instance is
+ * unclassified, the returned array elements must be all zero.
+ *
+ * @param inst
+ * the instance to be classified
+ * @return an array containing the estimated membership probabilities of the
+ * test instance in each class
+ */
+ double[] getVotesForInstance(Instance inst);
+
+ /**
+ * Resets this classifier. It must be similar to starting a new classifier
+ * from scratch.
+ *
+ */
+ void resetLearning();
+
+ /**
+ * Trains this classifier incrementally using the given instance.
+ *
+ * @param inst
+ * the instance to be used for training
+ */
+ void trainOnInstance(Instance inst);
+
+ /**
+ * Sets where to obtain the information of attributes of Instances
+ *
+ * @param dataset
+ * the dataset that contains the information
+ */
+ @Deprecated
+ public void setDataset(Instances dataset);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearnerProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearnerProcessor.java
index ae897f0..978a839 100755
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearnerProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/LocalLearnerProcessor.java
@@ -42,176 +42,184 @@
*/
final public class LocalLearnerProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -1577910988699148691L;
+ private static final long serialVersionUID = -1577910988699148691L;
- private static final Logger logger = LoggerFactory.getLogger(LocalLearnerProcessor.class);
-
- private LocalLearner model;
- private Stream outputStream;
- private int modelId;
- private long instancesCount = 0;
+ private static final Logger logger = LoggerFactory.getLogger(LocalLearnerProcessor.class);
- /**
- * Sets the learner.
- *
- * @param model the model to set
- */
- public void setLearner(LocalLearner model) {
- this.model = model;
- }
+ private LocalLearner model;
+ private Stream outputStream;
+ private int modelId;
+ private long instancesCount = 0;
- /**
- * Gets the learner.
- *
- * @return the model
- */
- public LocalLearner getLearner() {
- return model;
- }
+ /**
+ * Sets the learner.
+ *
+ * @param model
+ * the model to set
+ */
+ public void setLearner(LocalLearner model) {
+ this.model = model;
+ }
- /**
- * Set the output streams.
- *
- * @param outputStream the new output stream
- */
- public void setOutputStream(Stream outputStream) {
- this.outputStream = outputStream;
- }
-
- /**
- * Gets the output stream.
- *
- * @return the output stream
- */
- public Stream getOutputStream() {
- return outputStream;
- }
+ /**
+ * Gets the learner.
+ *
+ * @return the model
+ */
+ public LocalLearner getLearner() {
+ return model;
+ }
- /**
- * Gets the instances count.
- *
- * @return number of observation vectors used in training iteration.
- */
- public long getInstancesCount() {
- return instancesCount;
- }
+ /**
+ * Set the output streams.
+ *
+ * @param outputStream
+ * the new output stream
+ */
+ public void setOutputStream(Stream outputStream) {
+ this.outputStream = outputStream;
+ }
- /**
- * Update stats.
- *
- * @param event the event
- */
- private void updateStats(InstanceContentEvent event) {
- Instance inst = event.getInstance();
- this.model.trainOnInstance(inst);
- this.instancesCount++;
- if (this.changeDetector != null) {
- boolean correctlyClassifies = this.correctlyClassifies(inst);
- double oldEstimation = this.changeDetector.getEstimation();
- this.changeDetector.input(correctlyClassifies ? 0 : 1);
- if (this.changeDetector.getChange() && this.changeDetector.getEstimation() > oldEstimation) {
- //Start a new classifier
- this.model.resetLearning();
- this.changeDetector.resetLearning();
- }
- }
- }
+ /**
+ * Gets the output stream.
+ *
+ * @return the output stream
+ */
+ public Stream getOutputStream() {
+ return outputStream;
+ }
- /**
- * Gets whether this classifier correctly classifies an instance. Uses
- * getVotesForInstance to obtain the prediction and the instance to obtain
- * its true class.
- *
- *
- * @param inst the instance to be classified
- * @return true if the instance is correctly classified
- */
- private boolean correctlyClassifies(Instance inst) {
- return maxIndex(model.getVotesForInstance(inst)) == (int) inst.classValue();
- }
-
- /** The test. */
- protected int test; //to delete
-
- /**
- * On event.
- *
- * @param event the event
- * @return true, if successful
- */
- @Override
- public boolean process(ContentEvent event) {
+ /**
+ * Gets the instances count.
+ *
+ * @return number of observation vectors used in training iteration.
+ */
+ public long getInstancesCount() {
+ return instancesCount;
+ }
- InstanceContentEvent inEvent = (InstanceContentEvent) event;
- Instance instance = inEvent.getInstance();
+ /**
+ * Update stats.
+ *
+ * @param event
+ * the event
+ */
+ private void updateStats(InstanceContentEvent event) {
+ Instance inst = event.getInstance();
+ this.model.trainOnInstance(inst);
+ this.instancesCount++;
+ if (this.changeDetector != null) {
+ boolean correctlyClassifies = this.correctlyClassifies(inst);
+ double oldEstimation = this.changeDetector.getEstimation();
+ this.changeDetector.input(correctlyClassifies ? 0 : 1);
+ if (this.changeDetector.getChange() && this.changeDetector.getEstimation() > oldEstimation) {
+ // Start a new classifier
+ this.model.resetLearning();
+ this.changeDetector.resetLearning();
+ }
+ }
+ }
- if (inEvent.getInstanceIndex() < 0) {
- //end learning
- ResultContentEvent outContentEvent = new ResultContentEvent(-1, instance, 0,
- new double[0], inEvent.isLastEvent());
- outContentEvent.setClassifierIndex(this.modelId);
- outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- outputStream.put(outContentEvent);
- return false;
- }
-
- if (inEvent.isTesting()){
- double[] dist = model.getVotesForInstance(instance);
- ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
- instance, inEvent.getClassId(), dist, inEvent.isLastEvent());
- outContentEvent.setClassifierIndex(this.modelId);
- outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- logger.trace(inEvent.getInstanceIndex() + " {} {}", modelId, dist);
- outputStream.put(outContentEvent);
- }
-
- if (inEvent.isTraining()) {
- updateStats(inEvent);
- }
- return false;
- }
+ /**
+ * Gets whether this classifier correctly classifies an instance. Uses
+ * getVotesForInstance to obtain the prediction and the instance to obtain its
+ * true class.
+ *
+ *
+ * @param inst
+ * the instance to be classified
+ * @return true if the instance is correctly classified
+ */
+ private boolean correctlyClassifies(Instance inst) {
+ return maxIndex(model.getVotesForInstance(inst)) == (int) inst.classValue();
+ }
- /* (non-Javadoc)
- * @see samoa.core.Processor#onCreate(int)
- */
- @Override
- public void onCreate(int id) {
- this.modelId = id;
- model = model.create();
- }
+ /** The test. */
+ protected int test; // to delete
- /* (non-Javadoc)
- * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
- */
- @Override
- public Processor newProcessor(Processor sourceProcessor) {
- LocalLearnerProcessor newProcessor = new LocalLearnerProcessor();
- LocalLearnerProcessor originProcessor = (LocalLearnerProcessor) sourceProcessor;
+ /**
+ * On event.
+ *
+ * @param event
+ * the event
+ * @return true, if successful
+ */
+ @Override
+ public boolean process(ContentEvent event) {
- if (originProcessor.getLearner() != null){
- newProcessor.setLearner(originProcessor.getLearner().create());
- }
+ InstanceContentEvent inEvent = (InstanceContentEvent) event;
+ Instance instance = inEvent.getInstance();
- if (originProcessor.getChangeDetector() != null){
- newProcessor.setChangeDetector(originProcessor.getChangeDetector());
- }
+ if (inEvent.getInstanceIndex() < 0) {
+ // end learning
+ ResultContentEvent outContentEvent = new ResultContentEvent(-1, instance, 0,
+ new double[0], inEvent.isLastEvent());
+ outContentEvent.setClassifierIndex(this.modelId);
+ outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+ outputStream.put(outContentEvent);
+ return false;
+ }
- newProcessor.setOutputStream(originProcessor.getOutputStream());
- return newProcessor;
- }
-
- protected ChangeDetector changeDetector;
+ if (inEvent.isTesting()) {
+ double[] dist = model.getVotesForInstance(instance);
+ ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
+ instance, inEvent.getClassId(), dist, inEvent.isLastEvent());
+ outContentEvent.setClassifierIndex(this.modelId);
+ outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+ logger.trace(inEvent.getInstanceIndex() + " {} {}", modelId, dist);
+ outputStream.put(outContentEvent);
+ }
- public ChangeDetector getChangeDetector() {
- return this.changeDetector;
- }
+ if (inEvent.isTraining()) {
+ updateStats(inEvent);
+ }
+ return false;
+ }
- public void setChangeDetector(ChangeDetector cd) {
- this.changeDetector = cd;
- }
-
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#onCreate(int)
+ */
+ @Override
+ public void onCreate(int id) {
+ this.modelId = id;
+ model = model.create();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
+ */
+ @Override
+ public Processor newProcessor(Processor sourceProcessor) {
+ LocalLearnerProcessor newProcessor = new LocalLearnerProcessor();
+ LocalLearnerProcessor originProcessor = (LocalLearnerProcessor) sourceProcessor;
+
+ if (originProcessor.getLearner() != null) {
+ newProcessor.setLearner(originProcessor.getLearner().create());
+ }
+
+ if (originProcessor.getChangeDetector() != null) {
+ newProcessor.setChangeDetector(originProcessor.getChangeDetector());
+ }
+
+ newProcessor.setOutputStream(originProcessor.getOutputStream());
+ return newProcessor;
+ }
+
+ protected ChangeDetector changeDetector;
+
+ public ChangeDetector getChangeDetector() {
+ return this.changeDetector;
+ }
+
+ public void setChangeDetector(ChangeDetector cd) {
+ this.changeDetector = cd;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/NaiveBayes.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/NaiveBayes.java
index 7e9cb4a..915d09b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/NaiveBayes.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/NaiveBayes.java
@@ -41,229 +41,229 @@
*/
public class NaiveBayes implements LocalLearner {
- /**
- * Default smoothing factor. For now fixed to 1E-20.
- */
- private static final double ADDITIVE_SMOOTHING_FACTOR = 1e-20;
+ /**
+ * Default smoothing factor. For now fixed to 1E-20.
+ */
+ private static final double ADDITIVE_SMOOTHING_FACTOR = 1e-20;
- /**
- * serialVersionUID for serialization
- */
- private static final long serialVersionUID = 1325775209672996822L;
+ /**
+ * serialVersionUID for serialization
+ */
+ private static final long serialVersionUID = 1325775209672996822L;
- /**
- * Instance of a logger for use in this class.
- */
- private static final Logger logger = LoggerFactory.getLogger(NaiveBayes.class);
+ /**
+ * Instance of a logger for use in this class.
+ */
+ private static final Logger logger = LoggerFactory.getLogger(NaiveBayes.class);
- /**
- * The actual model.
- */
- protected Map<Integer, GaussianNumericAttributeClassObserver> attributeObservers;
+ /**
+ * The actual model.
+ */
+ protected Map<Integer, GaussianNumericAttributeClassObserver> attributeObservers;
- /**
- * Class statistics
- */
- protected Map<Integer, Double> classInstances;
+ /**
+ * Class statistics
+ */
+ protected Map<Integer, Double> classInstances;
- /**
- * Class zero-prototypes.
- */
- protected Map<Integer, Double> classPrototypes;
-
- /**
- * Retrieve the number of classes currently known to this local model
- *
- * @return the number of classes currently known to this local model
- */
- protected int getNumberOfClasses() {
- return this.classInstances.size();
- }
+ /**
+ * Class zero-prototypes.
+ */
+ protected Map<Integer, Double> classPrototypes;
- /**
- * Track training instances seen.
- */
- protected long instancesSeen = 0L;
+ /**
+ * Retrieve the number of classes currently known to this local model
+ *
+ * @return the number of classes currently known to this local model
+ */
+ protected int getNumberOfClasses() {
+ return this.classInstances.size();
+ }
- /**
- * Explicit no-arg constructor.
- */
- public NaiveBayes() {
- // Init the model
- resetLearning();
- }
+ /**
+ * Track training instances seen.
+ */
+ protected long instancesSeen = 0L;
- /**
- * Create an instance of this LocalLearner implementation.
- */
- @Override
- public LocalLearner create() {
- return new NaiveBayes();
- }
+ /**
+ * Explicit no-arg constructor.
+ */
+ public NaiveBayes() {
+ // Init the model
+ resetLearning();
+ }
- /**
- * Predicts the class memberships for a given instance. If an instance is
- * unclassified, the returned array elements will be all zero.
- *
- * Smoothing is being implemented by the AttributeClassObserver classes. At
- * the moment, the GaussianNumericProbabilityAttributeClassObserver needs no
- * smoothing as it processes continuous variables.
- *
- * Please note that we transform the scores to log space to avoid underflow,
- * and we replace the multiplication with addition.
- *
- * The resulting scores are no longer probabilities, as a mixture of
- * probability densities and probabilities can be used in the computation.
- *
- * @param inst
- * the instance to be classified
- * @return an array containing the estimated membership scores of the test
- * instance in each class, in log space.
- */
- @Override
- public double[] getVotesForInstance(Instance inst) {
- // Prepare the results array
- double[] votes = new double[getNumberOfClasses()];
- // Over all classes
- for (int classIndex = 0; classIndex < votes.length; classIndex++) {
- // Get the prior for this class
- votes[classIndex] = Math.log(getPrior(classIndex));
- // Iterate over the instance attributes
- for (int index = 0; index < inst.numAttributes(); index++) {
- int attributeID = inst.index(index);
- // Skip class attribute
- if (attributeID == inst.classIndex())
- continue;
- Double value = inst.value(attributeID);
- // Get the observer for the given attribute
- GaussianNumericAttributeClassObserver obs = attributeObservers.get(attributeID);
- // Init the estimator to null by default
- GaussianEstimator estimator = null;
- if (obs != null && obs.getEstimator(classIndex) != null) {
- // Get the estimator
- estimator = obs.getEstimator(classIndex);
- }
- double valueNonZero;
- // The null case should be handled by smoothing!
- if (estimator != null) {
- // Get the score for a NON-ZERO attribute value
- valueNonZero = estimator.probabilityDensity(value);
- }
- // We don't have an estimator
- else {
- // Assign a very small probability that we do see this value
- valueNonZero = ADDITIVE_SMOOTHING_FACTOR;
- }
- votes[classIndex] += Math.log(valueNonZero); // - Math.log(valueZero);
- }
- // Check for null in the case of prequential evaluation
- if (this.classPrototypes.get(classIndex) != null) {
- // Add the prototype for the class, already in log space
- votes[classIndex] += Math.log(this.classPrototypes.get(classIndex));
- }
- }
- return votes;
- }
-
- /**
- * Compute the prior for the given classIndex.
- *
- * Implemented by maximum likelihood at the moment.
- *
- * @param classIndex
- * Id of the class for which we want to compute the prior.
- * @return Prior probability for the requested class
- */
- private double getPrior(int classIndex) {
- // Maximum likelihood
- Double currentCount = this.classInstances.get(classIndex);
- if (currentCount == null || currentCount == 0)
- return 0;
- else
- return currentCount * 1. / this.instancesSeen;
- }
+ /**
+ * Create an instance of this LocalLearner implementation.
+ */
+ @Override
+ public LocalLearner create() {
+ return new NaiveBayes();
+ }
- /**
- * Resets this classifier. It must be similar to starting a new classifier
- * from scratch.
- */
- @Override
- public void resetLearning() {
- // Reset priors
- this.instancesSeen = 0L;
- this.classInstances = new HashMap<>();
- this.classPrototypes = new HashMap<>();
- // Init the attribute observers
- this.attributeObservers = new HashMap<>();
- }
+ /**
+ * Predicts the class memberships for a given instance. If an instance is
+ * unclassified, the returned array elements will be all zero.
+ *
+ * Smoothing is being implemented by the AttributeClassObserver classes. At
+ * the moment, the GaussianNumericProbabilityAttributeClassObserver needs no
+ * smoothing as it processes continuous variables.
+ *
+ * Please note that we transform the scores to log space to avoid underflow,
+ * and we replace the multiplication with addition.
+ *
+ * The resulting scores are no longer probabilities, as a mixture of
+ * probability densities and probabilities can be used in the computation.
+ *
+ * @param inst
+ * the instance to be classified
+ * @return an array containing the estimated membership scores of the test
+ * instance in each class, in log space.
+ */
+ @Override
+ public double[] getVotesForInstance(Instance inst) {
+ // Prepare the results array
+ double[] votes = new double[getNumberOfClasses()];
+ // Over all classes
+ for (int classIndex = 0; classIndex < votes.length; classIndex++) {
+ // Get the prior for this class
+ votes[classIndex] = Math.log(getPrior(classIndex));
+ // Iterate over the instance attributes
+ for (int index = 0; index < inst.numAttributes(); index++) {
+ int attributeID = inst.index(index);
+ // Skip class attribute
+ if (attributeID == inst.classIndex())
+ continue;
+ Double value = inst.value(attributeID);
+ // Get the observer for the given attribute
+ GaussianNumericAttributeClassObserver obs = attributeObservers.get(attributeID);
+ // Init the estimator to null by default
+ GaussianEstimator estimator = null;
+ if (obs != null && obs.getEstimator(classIndex) != null) {
+ // Get the estimator
+ estimator = obs.getEstimator(classIndex);
+ }
+ double valueNonZero;
+ // The null case should be handled by smoothing!
+ if (estimator != null) {
+ // Get the score for a NON-ZERO attribute value
+ valueNonZero = estimator.probabilityDensity(value);
+ }
+ // We don't have an estimator
+ else {
+ // Assign a very small probability that we do see this value
+ valueNonZero = ADDITIVE_SMOOTHING_FACTOR;
+ }
+ votes[classIndex] += Math.log(valueNonZero); // - Math.log(valueZero);
+ }
+ // Check for null in the case of prequential evaluation
+ if (this.classPrototypes.get(classIndex) != null) {
+ // Add the prototype for the class, already in log space
+ votes[classIndex] += Math.log(this.classPrototypes.get(classIndex));
+ }
+ }
+ return votes;
+ }
- /**
- * Trains this classifier incrementally using the given instance.
- *
- * @param inst
- * the instance to be used for training
- */
- @Override
- public void trainOnInstance(Instance inst) {
- // Update class statistics with weights
- int classIndex = (int) inst.classValue();
- Double weight = this.classInstances.get(classIndex);
- if (weight == null)
- weight = 0.;
- this.classInstances.put(classIndex, weight + inst.weight());
-
- // Get the class prototype
- Double classPrototype = this.classPrototypes.get(classIndex);
- if (classPrototype == null)
- classPrototype = 1.;
-
- // Iterate over the attributes of the given instance
- for (int attributePosition = 0; attributePosition < inst
- .numAttributes(); attributePosition++) {
- // Get the attribute index - Dense -> 1:1, Sparse is remapped
- int attributeID = inst.index(attributePosition);
- // Skip class attribute
- if (attributeID == inst.classIndex())
- continue;
- // Get the attribute observer for the current attribute
- GaussianNumericAttributeClassObserver obs = this.attributeObservers
- .get(attributeID);
- // Lazy init of observers, if null, instantiate a new one
- if (obs == null) {
- // FIXME: At this point, we model everything as a numeric
- // attribute
- obs = new GaussianNumericAttributeClassObserver();
- this.attributeObservers.put(attributeID, obs);
- }
-
- // Get the probability density function under the current model
- GaussianEstimator obs_estimator = obs.getEstimator(classIndex);
- if (obs_estimator != null) {
- // Fetch the probability that the feature value is zero
- double probDens_zero_current = obs_estimator.probabilityDensity(0);
- classPrototype -= probDens_zero_current;
- }
-
- // FIXME: Sanity check on data values, for now just learn
- // Learn attribute value for given class
- obs.observeAttributeClass(inst.valueSparse(attributePosition),
- (int) inst.classValue(), inst.weight());
-
- // Update obs_estimator to fetch the pdf from the updated model
- obs_estimator = obs.getEstimator(classIndex);
- // Fetch the probability that the feature value is zero
- double probDens_zero_updated = obs_estimator.probabilityDensity(0);
- // Update the class prototype
- classPrototype += probDens_zero_updated;
- }
- // Store the class prototype
- this.classPrototypes.put(classIndex, classPrototype);
- // Count another training instance
- this.instancesSeen++;
- }
+ /**
+ * Compute the prior for the given classIndex.
+ *
+ * Implemented by maximum likelihood at the moment.
+ *
+ * @param classIndex
+ * Id of the class for which we want to compute the prior.
+ * @return Prior probability for the requested class
+ */
+ private double getPrior(int classIndex) {
+ // Maximum likelihood
+ Double currentCount = this.classInstances.get(classIndex);
+ if (currentCount == null || currentCount == 0)
+ return 0;
+ else
+ return currentCount * 1. / this.instancesSeen;
+ }
- @Override
- public void setDataset(Instances dataset) {
- // Do nothing
- }
+ /**
+ * Resets this classifier. It must be similar to starting a new classifier
+ * from scratch.
+ */
+ @Override
+ public void resetLearning() {
+ // Reset priors
+ this.instancesSeen = 0L;
+ this.classInstances = new HashMap<>();
+ this.classPrototypes = new HashMap<>();
+ // Init the attribute observers
+ this.attributeObservers = new HashMap<>();
+ }
+
+ /**
+ * Trains this classifier incrementally using the given instance.
+ *
+ * @param inst
+ * the instance to be used for training
+ */
+ @Override
+ public void trainOnInstance(Instance inst) {
+ // Update class statistics with weights
+ int classIndex = (int) inst.classValue();
+ Double weight = this.classInstances.get(classIndex);
+ if (weight == null)
+ weight = 0.;
+ this.classInstances.put(classIndex, weight + inst.weight());
+
+ // Get the class prototype
+ Double classPrototype = this.classPrototypes.get(classIndex);
+ if (classPrototype == null)
+ classPrototype = 1.;
+
+ // Iterate over the attributes of the given instance
+ for (int attributePosition = 0; attributePosition < inst
+ .numAttributes(); attributePosition++) {
+ // Get the attribute index - Dense -> 1:1, Sparse is remapped
+ int attributeID = inst.index(attributePosition);
+ // Skip class attribute
+ if (attributeID == inst.classIndex())
+ continue;
+ // Get the attribute observer for the current attribute
+ GaussianNumericAttributeClassObserver obs = this.attributeObservers
+ .get(attributeID);
+ // Lazy init of observers, if null, instantiate a new one
+ if (obs == null) {
+ // FIXME: At this point, we model everything as a numeric
+ // attribute
+ obs = new GaussianNumericAttributeClassObserver();
+ this.attributeObservers.put(attributeID, obs);
+ }
+
+ // Get the probability density function under the current model
+ GaussianEstimator obs_estimator = obs.getEstimator(classIndex);
+ if (obs_estimator != null) {
+ // Fetch the probability that the feature value is zero
+ double probDens_zero_current = obs_estimator.probabilityDensity(0);
+ classPrototype -= probDens_zero_current;
+ }
+
+ // FIXME: Sanity check on data values, for now just learn
+ // Learn attribute value for given class
+ obs.observeAttributeClass(inst.valueSparse(attributePosition),
+ (int) inst.classValue(), inst.weight());
+
+ // Update obs_estimator to fetch the pdf from the updated model
+ obs_estimator = obs.getEstimator(classIndex);
+ // Fetch the probability that the feature value is zero
+ double probDens_zero_updated = obs_estimator.probabilityDensity(0);
+ // Update the class prototype
+ classPrototype += probDens_zero_updated;
+ }
+ // Store the class prototype
+ this.classPrototypes.put(classIndex, classPrototype);
+ // Count another training instance
+ this.instancesSeen++;
+ }
+
+ @Override
+ public void setDataset(Instances dataset) {
+ // Do nothing
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SimpleClassifierAdapter.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SimpleClassifierAdapter.java
index a3fb89f..75c5284 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SimpleClassifierAdapter.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SimpleClassifierAdapter.java
@@ -30,121 +30,125 @@
import com.yahoo.labs.samoa.moa.classifiers.functions.MajorityClass;
/**
- *
+ *
* Base class for adapting external classifiers.
- *
+ *
*/
public class SimpleClassifierAdapter implements LocalLearner, Configurable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 4372366401338704353L;
-
- public ClassOption learnerOption = new ClassOption("learner", 'l',
- "Classifier to train.", com.yahoo.labs.samoa.moa.classifiers.Classifier.class, MajorityClass.class.getName());
- /**
- * The learner.
- */
- protected com.yahoo.labs.samoa.moa.classifiers.Classifier learner;
-
- /**
- * The is init.
- */
- protected Boolean isInit;
-
- /**
- * The dataset.
- */
- protected Instances dataset;
+ private static final long serialVersionUID = 4372366401338704353L;
- @Override
- public void setDataset(Instances dataset) {
- this.dataset = dataset;
- }
+ public ClassOption learnerOption = new ClassOption("learner", 'l',
+ "Classifier to train.", com.yahoo.labs.samoa.moa.classifiers.Classifier.class, MajorityClass.class.getName());
+ /**
+ * The learner.
+ */
+ protected com.yahoo.labs.samoa.moa.classifiers.Classifier learner;
- /**
- * Instantiates a new learner.
- *
- * @param learner the learner
- * @param dataset the dataset
- */
- public SimpleClassifierAdapter(com.yahoo.labs.samoa.moa.classifiers.Classifier learner, Instances dataset) {
- this.learner = learner.copy();
- this.isInit = false;
- this.dataset = dataset;
- }
+ /**
+ * The is init.
+ */
+ protected Boolean isInit;
- /**
- * Instantiates a new learner.
- *
- */
- public SimpleClassifierAdapter() {
- this.learner = ((com.yahoo.labs.samoa.moa.classifiers.Classifier) this.learnerOption.getValue()).copy();
- this.isInit = false;
- }
+ /**
+ * The dataset.
+ */
+ protected Instances dataset;
- /**
- * Creates a new learner object.
- *
- * @return the learner
- */
- @Override
- public SimpleClassifierAdapter create() {
- SimpleClassifierAdapter l = new SimpleClassifierAdapter(learner, dataset);
- if (dataset == null) {
- System.out.println("dataset null while creating");
- }
- return l;
- }
+ @Override
+ public void setDataset(Instances dataset) {
+ this.dataset = dataset;
+ }
- /**
- * Trains this classifier incrementally using the given instance.
- *
- * @param inst the instance to be used for training
- */
- @Override
- public void trainOnInstance(Instance inst) {
- if (!this.isInit) {
- this.isInit = true;
- InstancesHeader instances = new InstancesHeader(dataset);
- this.learner.setModelContext(instances);
- this.learner.prepareForUse();
- }
- if (inst.weight() > 0) {
- inst.setDataset(dataset);
- learner.trainOnInstance(inst);
- }
- }
+ /**
+ * Instantiates a new learner.
+ *
+ * @param learner
+ * the learner
+ * @param dataset
+ * the dataset
+ */
+ public SimpleClassifierAdapter(com.yahoo.labs.samoa.moa.classifiers.Classifier learner, Instances dataset) {
+ this.learner = learner.copy();
+ this.isInit = false;
+ this.dataset = dataset;
+ }
- /**
- * Predicts the class memberships for a given instance. If an instance is
- * unclassified, the returned array elements must be all zero.
- *
- * @param inst the instance to be classified
- * @return an array containing the estimated membership probabilities of the
- * test instance in each class
- */
- @Override
- public double[] getVotesForInstance(Instance inst) {
- double[] ret;
- inst.setDataset(dataset);
- if (!this.isInit) {
- ret = new double[dataset.numClasses()];
- } else {
- ret = learner.getVotesForInstance(inst);
- }
- return ret;
- }
+ /**
+ * Instantiates a new learner.
+ *
+ */
+ public SimpleClassifierAdapter() {
+ this.learner = ((com.yahoo.labs.samoa.moa.classifiers.Classifier) this.learnerOption.getValue()).copy();
+ this.isInit = false;
+ }
- /**
- * Resets this classifier. It must be similar to starting a new classifier
- * from scratch.
- *
- */
- @Override
- public void resetLearning() {
- learner.resetLearning();
+ /**
+ * Creates a new learner object.
+ *
+ * @return the learner
+ */
+ @Override
+ public SimpleClassifierAdapter create() {
+ SimpleClassifierAdapter l = new SimpleClassifierAdapter(learner, dataset);
+ if (dataset == null) {
+ System.out.println("dataset null while creating");
}
+ return l;
+ }
+
+ /**
+ * Trains this classifier incrementally using the given instance.
+ *
+ * @param inst
+ * the instance to be used for training
+ */
+ @Override
+ public void trainOnInstance(Instance inst) {
+ if (!this.isInit) {
+ this.isInit = true;
+ InstancesHeader instances = new InstancesHeader(dataset);
+ this.learner.setModelContext(instances);
+ this.learner.prepareForUse();
+ }
+ if (inst.weight() > 0) {
+ inst.setDataset(dataset);
+ learner.trainOnInstance(inst);
+ }
+ }
+
+ /**
+ * Predicts the class memberships for a given instance. If an instance is
+ * unclassified, the returned array elements must be all zero.
+ *
+ * @param inst
+ * the instance to be classified
+ * @return an array containing the estimated membership probabilities of the
+ * test instance in each class
+ */
+ @Override
+ public double[] getVotesForInstance(Instance inst) {
+ double[] ret;
+ inst.setDataset(dataset);
+ if (!this.isInit) {
+ ret = new double[dataset.numClasses()];
+ } else {
+ ret = learner.getVotesForInstance(inst);
+ }
+ return ret;
+ }
+
+ /**
+ * Resets this classifier. It must be similar to starting a new classifier
+ * from scratch.
+ *
+ */
+ @Override
+ public void resetLearning() {
+ learner.resetLearning();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SingleClassifier.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SingleClassifier.java
index affc935..46352e0 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SingleClassifier.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/SingleClassifier.java
@@ -36,6 +36,7 @@
import com.yahoo.labs.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
import com.yahoo.labs.samoa.topology.Stream;
import com.yahoo.labs.samoa.topology.TopologyBuilder;
+
/**
*
* Classifier that contain a single classifier.
@@ -43,67 +44,67 @@
*/
public final class SingleClassifier implements Learner, AdaptiveLearner, Configurable {
- private static final long serialVersionUID = 684111382631697031L;
-
- private LocalLearnerProcessor learnerP;
-
- private Stream resultStream;
+ private static final long serialVersionUID = 684111382631697031L;
- private Instances dataset;
+ private LocalLearnerProcessor learnerP;
- public ClassOption learnerOption = new ClassOption("learner", 'l',
- "Classifier to train.", LocalLearner.class, SimpleClassifierAdapter.class.getName());
-
- private TopologyBuilder builder;
+ private Stream resultStream;
- private int parallelism;
+ private Instances dataset;
+ public ClassOption learnerOption = new ClassOption("learner", 'l',
+ "Classifier to train.", LocalLearner.class, SimpleClassifierAdapter.class.getName());
- @Override
- public void init(TopologyBuilder builder, Instances dataset, int parallelism){
- this.builder = builder;
- this.dataset = dataset;
- this.parallelism = parallelism;
- this.setLayout();
- }
+ private TopologyBuilder builder;
+ private int parallelism;
- protected void setLayout() {
- learnerP = new LocalLearnerProcessor();
- learnerP.setChangeDetector(this.getChangeDetector());
- LocalLearner learner = this.learnerOption.getValue();
- learner.setDataset(this.dataset);
- learnerP.setLearner(learner);
-
- //learnerPI = this.builder.createPi(learnerP, 1);
- this.builder.addProcessor(learnerP, parallelism);
- resultStream = this.builder.createStream(learnerP);
+ @Override
+ public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+ this.builder = builder;
+ this.dataset = dataset;
+ this.parallelism = parallelism;
+ this.setLayout();
+ }
- learnerP.setOutputStream(resultStream);
- }
+ protected void setLayout() {
+ learnerP = new LocalLearnerProcessor();
+ learnerP.setChangeDetector(this.getChangeDetector());
+ LocalLearner learner = this.learnerOption.getValue();
+ learner.setDataset(this.dataset);
+ learnerP.setLearner(learner);
- @Override
- public Processor getInputProcessor() {
- return learnerP;
- }
+ // learnerPI = this.builder.createPi(learnerP, 1);
+ this.builder.addProcessor(learnerP, parallelism);
+ resultStream = this.builder.createStream(learnerP);
- /* (non-Javadoc)
- * @see samoa.learners.Learner#getResultStreams()
- */
- @Override
- public Set<Stream> getResultStreams() {
- return ImmutableSet.of(this.resultStream);
- }
+ learnerP.setOutputStream(resultStream);
+ }
- protected ChangeDetector changeDetector;
+ @Override
+ public Processor getInputProcessor() {
+ return learnerP;
+ }
- @Override
- public ChangeDetector getChangeDetector() {
- return this.changeDetector;
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.learners.Learner#getResultStreams()
+ */
+ @Override
+ public Set<Stream> getResultStreams() {
+ return ImmutableSet.of(this.resultStream);
+ }
- @Override
- public void setChangeDetector(ChangeDetector cd) {
- this.changeDetector = cd;
- }
+ protected ChangeDetector changeDetector;
+
+ @Override
+ public ChangeDetector getChangeDetector() {
+ return this.changeDetector;
+ }
+
+ @Override
+ public void setChangeDetector(ChangeDetector cd) {
+ this.changeDetector = cd;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/AdaptiveBagging.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/AdaptiveBagging.java
index aba3d1d..3bdea57 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/AdaptiveBagging.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/AdaptiveBagging.java
@@ -47,103 +47,105 @@
* The Bagging Classifier by Oza and Russell.
*/
public class AdaptiveBagging implements Learner, Configurable {
-
- /** Logger */
+
+ /** Logger */
private static final Logger logger = LoggerFactory.getLogger(AdaptiveBagging.class);
- /** The Constant serialVersionUID. */
- private static final long serialVersionUID = -2971850264864952099L;
-
- /** The base learner option. */
- public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
- "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
+ /** The Constant serialVersionUID. */
+ private static final long serialVersionUID = -2971850264864952099L;
- /** The ensemble size option. */
- public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
- "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
+ /** The base learner option. */
+ public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
+ "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
- public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'd',
+ /** The ensemble size option. */
+ public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
+ "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
+
+ public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'd',
"Drift detection method to use.", ChangeDetector.class, ADWINChangeDetector.class.getName());
- /** The distributor processor. */
- private BaggingDistributorProcessor distributorP;
+ /** The distributor processor. */
+ private BaggingDistributorProcessor distributorP;
- /** The result stream. */
- protected Stream resultStream;
-
- /** The dataset. */
- private Instances dataset;
-
- protected Learner classifier;
-
+ /** The result stream. */
+ protected Stream resultStream;
+
+ /** The dataset. */
+ private Instances dataset;
+
+ protected Learner classifier;
+
protected int parallelism;
- /**
- * Sets the layout.
- */
- protected void setLayout() {
+ /**
+ * Sets the layout.
+ */
+ protected void setLayout() {
- int sizeEnsemble = this.ensembleSizeOption.getValue();
+ int sizeEnsemble = this.ensembleSizeOption.getValue();
- distributorP = new BaggingDistributorProcessor();
- distributorP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(distributorP, 1);
-
- //instantiate classifier
- classifier = this.baseLearnerOption.getValue();
- if (classifier instanceof AdaptiveLearner) {
- // logger.info("Building an AdaptiveLearner {}", classifier.getClass().getName());
- AdaptiveLearner ada = (AdaptiveLearner) classifier;
- ada.setChangeDetector((ChangeDetector) this.driftDetectionMethodOption.getValue());
- }
- classifier.init(builder, this.dataset, sizeEnsemble);
-
- PredictionCombinerProcessor predictionCombinerP= new PredictionCombinerProcessor();
- predictionCombinerP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(predictionCombinerP, 1);
-
- //Streams
- resultStream = this.builder.createStream(predictionCombinerP);
- predictionCombinerP.setOutputStream(resultStream);
+ distributorP = new BaggingDistributorProcessor();
+ distributorP.setSizeEnsemble(sizeEnsemble);
+ this.builder.addProcessor(distributorP, 1);
- for (Stream subResultStream:classifier.getResultStreams()) {
- this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
- }
-
- /* The training stream. */
- Stream testingStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
-
- /* The prediction stream. */
- Stream predictionStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
-
- distributorP.setOutputStream(testingStream);
- distributorP.setPredictionStream(predictionStream);
- }
+ // instantiate classifier
+ classifier = this.baseLearnerOption.getValue();
+ if (classifier instanceof AdaptiveLearner) {
+ // logger.info("Building an AdaptiveLearner {}",
+ // classifier.getClass().getName());
+ AdaptiveLearner ada = (AdaptiveLearner) classifier;
+ ada.setChangeDetector((ChangeDetector) this.driftDetectionMethodOption.getValue());
+ }
+ classifier.init(builder, this.dataset, sizeEnsemble);
- /** The builder. */
- private TopologyBuilder builder;
-
-
- @Override
- public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
- this.builder = builder;
- this.dataset = dataset;
- this.parallelism = parallelism;
- this.setLayout();
- }
+ PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor();
+ predictionCombinerP.setSizeEnsemble(sizeEnsemble);
+ this.builder.addProcessor(predictionCombinerP, 1);
- @Override
- public Processor getInputProcessor() {
- return distributorP;
- }
-
- /* (non-Javadoc)
- * @see samoa.learners.Learner#getResultStreams()
- */
- @Override
- public Set<Stream> getResultStreams() {
- return ImmutableSet.of(this.resultStream);
- }
+ // Streams
+ resultStream = this.builder.createStream(predictionCombinerP);
+ predictionCombinerP.setOutputStream(resultStream);
+
+ for (Stream subResultStream : classifier.getResultStreams()) {
+ this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
+ }
+
+ /* The training stream. */
+ Stream testingStream = this.builder.createStream(distributorP);
+ this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
+
+ /* The prediction stream. */
+ Stream predictionStream = this.builder.createStream(distributorP);
+ this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
+
+ distributorP.setOutputStream(testingStream);
+ distributorP.setPredictionStream(predictionStream);
+ }
+
+ /** The builder. */
+ private TopologyBuilder builder;
+
+ @Override
+ public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+ this.builder = builder;
+ this.dataset = dataset;
+ this.parallelism = parallelism;
+ this.setLayout();
+ }
+
+ @Override
+ public Processor getInputProcessor() {
+ return distributorP;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.learners.Learner#getResultStreams()
+ */
+ @Override
+ public Set<Stream> getResultStreams() {
+ return ImmutableSet.of(this.resultStream);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Bagging.java
index 9f99ff1..3a78933 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Bagging.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Bagging.java
@@ -40,99 +40,99 @@
/**
* The Bagging Classifier by Oza and Russell.
*/
-public class Bagging implements Learner , Configurable {
+public class Bagging implements Learner, Configurable {
- /** The Constant serialVersionUID. */
- private static final long serialVersionUID = -2971850264864952099L;
-
- /** The base learner option. */
- public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
- "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
+ /** The Constant serialVersionUID. */
+ private static final long serialVersionUID = -2971850264864952099L;
-
- /** The ensemble size option. */
- public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
- "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
+ /** The base learner option. */
+ public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
+ "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
- /** The distributor processor. */
- private BaggingDistributorProcessor distributorP;
-
- /** The training stream. */
- private Stream testingStream;
-
- /** The prediction stream. */
- private Stream predictionStream;
-
- /** The result stream. */
- protected Stream resultStream;
-
- /** The dataset. */
- private Instances dataset;
-
- protected Learner classifier;
-
- protected int parallelism;
+ /** The ensemble size option. */
+ public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
+ "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
- /**
- * Sets the layout.
- */
- protected void setLayout() {
+ /** The distributor processor. */
+ private BaggingDistributorProcessor distributorP;
- int sizeEnsemble = this.ensembleSizeOption.getValue();
+ /** The training stream. */
+ private Stream testingStream;
- distributorP = new BaggingDistributorProcessor();
- distributorP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(distributorP, 1);
-
- //instantiate classifier
- classifier = (Learner) this.baseLearnerOption.getValue();
- classifier.init(builder, this.dataset, sizeEnsemble);
-
- PredictionCombinerProcessor predictionCombinerP= new PredictionCombinerProcessor();
- predictionCombinerP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(predictionCombinerP, 1);
-
- //Streams
- resultStream = this.builder.createStream(predictionCombinerP);
- predictionCombinerP.setOutputStream(resultStream);
+ /** The prediction stream. */
+ private Stream predictionStream;
- for (Stream subResultStream:classifier.getResultStreams()) {
- this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
- }
-
- testingStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
-
- predictionStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
-
- distributorP.setOutputStream(testingStream);
- distributorP.setPredictionStream(predictionStream);
- }
+ /** The result stream. */
+ protected Stream resultStream;
- /** The builder. */
- private TopologyBuilder builder;
-
-
- @Override
- public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
- this.builder = builder;
- this.dataset = dataset;
- this.parallelism = parallelism;
- this.setLayout();
- }
+ /** The dataset. */
+ private Instances dataset;
- @Override
- public Processor getInputProcessor() {
- return distributorP;
- }
-
- /* (non-Javadoc)
- * @see samoa.learners.Learner#getResultStreams()
- */
- @Override
- public Set<Stream> getResultStreams() {
- Set<Stream> streams = ImmutableSet.of(this.resultStream);
- return streams;
+ protected Learner classifier;
+
+ protected int parallelism;
+
+ /**
+ * Sets the layout.
+ */
+ protected void setLayout() {
+
+ int sizeEnsemble = this.ensembleSizeOption.getValue();
+
+ distributorP = new BaggingDistributorProcessor();
+ distributorP.setSizeEnsemble(sizeEnsemble);
+ this.builder.addProcessor(distributorP, 1);
+
+ // instantiate classifier
+ classifier = (Learner) this.baseLearnerOption.getValue();
+ classifier.init(builder, this.dataset, sizeEnsemble);
+
+ PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor();
+ predictionCombinerP.setSizeEnsemble(sizeEnsemble);
+ this.builder.addProcessor(predictionCombinerP, 1);
+
+ // Streams
+ resultStream = this.builder.createStream(predictionCombinerP);
+ predictionCombinerP.setOutputStream(resultStream);
+
+ for (Stream subResultStream : classifier.getResultStreams()) {
+ this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
}
+
+ testingStream = this.builder.createStream(distributorP);
+ this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
+
+ predictionStream = this.builder.createStream(distributorP);
+ this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
+
+ distributorP.setOutputStream(testingStream);
+ distributorP.setPredictionStream(predictionStream);
+ }
+
+ /** The builder. */
+ private TopologyBuilder builder;
+
+ @Override
+ public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+ this.builder = builder;
+ this.dataset = dataset;
+ this.parallelism = parallelism;
+ this.setLayout();
+ }
+
+ @Override
+ public Processor getInputProcessor() {
+ return distributorP;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.learners.Learner#getResultStreams()
+ */
+ @Override
+ public Set<Stream> getResultStreams() {
+ Set<Stream> streams = ImmutableSet.of(this.resultStream);
+ return streams;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
index 65c782b..44264ac 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java
@@ -35,167 +35,174 @@
/**
* The Class BaggingDistributorPE.
*/
-public class BaggingDistributorProcessor implements Processor{
+public class BaggingDistributorProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -1550901409625192730L;
+ private static final long serialVersionUID = -1550901409625192730L;
- /** The size ensemble. */
- private int sizeEnsemble;
-
- /** The training stream. */
- private Stream trainingStream;
-
- /** The prediction stream. */
- private Stream predictionStream;
+ /** The size ensemble. */
+ private int sizeEnsemble;
- /**
- * On event.
- *
- * @param event the event
- * @return true, if successful
- */
- public boolean process(ContentEvent event) {
- InstanceContentEvent inEvent = (InstanceContentEvent) event; //((s4Event) event).getContentEvent();
- //InstanceEvent inEvent = (InstanceEvent) event;
+ /** The training stream. */
+ private Stream trainingStream;
- if (inEvent.getInstanceIndex() < 0) {
- // End learning
- predictionStream.put(event);
- return false;
- }
+ /** The prediction stream. */
+ private Stream predictionStream;
+ /**
+ * On event.
+ *
+ * @param event
+ * the event
+ * @return true, if successful
+ */
+ public boolean process(ContentEvent event) {
+ InstanceContentEvent inEvent = (InstanceContentEvent) event; // ((s4Event)
+ // event).getContentEvent();
+ // InstanceEvent inEvent = (InstanceEvent) event;
- if (inEvent.isTesting()){
- Instance trainInst = inEvent.getInstance();
- for (int i = 0; i < sizeEnsemble; i++) {
- Instance weightedInst = trainInst.copy();
- //weightedInst.setWeight(trainInst.weight() * k);
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
- inEvent.getInstanceIndex(), weightedInst, false, true);
- instanceContentEvent.setClassifierIndex(i);
- instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- predictionStream.put(instanceContentEvent);
- }
- }
-
- /* Estimate model parameters using the training data. */
- if (inEvent.isTraining()) {
- train(inEvent);
- }
- return false;
- }
+ if (inEvent.getInstanceIndex() < 0) {
+ // End learning
+ predictionStream.put(event);
+ return false;
+ }
- /** The random. */
- protected Random random = new Random();
+ if (inEvent.isTesting()) {
+ Instance trainInst = inEvent.getInstance();
+ for (int i = 0; i < sizeEnsemble; i++) {
+ Instance weightedInst = trainInst.copy();
+ // weightedInst.setWeight(trainInst.weight() * k);
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
+ inEvent.getInstanceIndex(), weightedInst, false, true);
+ instanceContentEvent.setClassifierIndex(i);
+ instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+ predictionStream.put(instanceContentEvent);
+ }
+ }
- /**
- * Train.
- *
- * @param inEvent the in event
- */
- protected void train(InstanceContentEvent inEvent) {
- Instance trainInst = inEvent.getInstance();
- for (int i = 0; i < sizeEnsemble; i++) {
- int k = MiscUtils.poisson(1.0, this.random);
- if (k > 0) {
- Instance weightedInst = trainInst.copy();
- weightedInst.setWeight(trainInst.weight() * k);
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
- inEvent.getInstanceIndex(), weightedInst, true, false);
- instanceContentEvent.setClassifierIndex(i);
- instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- trainingStream.put(instanceContentEvent);
- }
- }
- }
+ /* Estimate model parameters using the training data. */
+ if (inEvent.isTraining()) {
+ train(inEvent);
+ }
+ return false;
+ }
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.ProcessingElement#onCreate()
- */
- @Override
- public void onCreate(int id) {
- //do nothing
- }
+ /** The random. */
+ protected Random random = new Random();
+ /**
+ * Train.
+ *
+ * @param inEvent
+ * the in event
+ */
+ protected void train(InstanceContentEvent inEvent) {
+ Instance trainInst = inEvent.getInstance();
+ for (int i = 0; i < sizeEnsemble; i++) {
+ int k = MiscUtils.poisson(1.0, this.random);
+ if (k > 0) {
+ Instance weightedInst = trainInst.copy();
+ weightedInst.setWeight(trainInst.weight() * k);
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
+ inEvent.getInstanceIndex(), weightedInst, true, false);
+ instanceContentEvent.setClassifierIndex(i);
+ instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+ trainingStream.put(instanceContentEvent);
+ }
+ }
+ }
- /**
- * Gets the training stream.
- *
- * @return the training stream
- */
- public Stream getTrainingStream() {
- return trainingStream;
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.s4.core.ProcessingElement#onCreate()
+ */
+ @Override
+ public void onCreate(int id) {
+ // do nothing
+ }
- /**
- * Sets the training stream.
- *
- * @param trainingStream the new training stream
- */
- public void setOutputStream(Stream trainingStream) {
- this.trainingStream = trainingStream;
- }
+ /**
+ * Gets the training stream.
+ *
+ * @return the training stream
+ */
+ public Stream getTrainingStream() {
+ return trainingStream;
+ }
- /**
- * Gets the prediction stream.
- *
- * @return the prediction stream
- */
- public Stream getPredictionStream() {
- return predictionStream;
- }
+ /**
+ * Sets the training stream.
+ *
+ * @param trainingStream
+ * the new training stream
+ */
+ public void setOutputStream(Stream trainingStream) {
+ this.trainingStream = trainingStream;
+ }
- /**
- * Sets the prediction stream.
- *
- * @param predictionStream the new prediction stream
- */
- public void setPredictionStream(Stream predictionStream) {
- this.predictionStream = predictionStream;
- }
+ /**
+ * Gets the prediction stream.
+ *
+ * @return the prediction stream
+ */
+ public Stream getPredictionStream() {
+ return predictionStream;
+ }
- /**
- * Gets the size ensemble.
- *
- * @return the size ensemble
- */
- public int getSizeEnsemble() {
- return sizeEnsemble;
- }
+ /**
+ * Sets the prediction stream.
+ *
+ * @param predictionStream
+ * the new prediction stream
+ */
+ public void setPredictionStream(Stream predictionStream) {
+ this.predictionStream = predictionStream;
+ }
- /**
- * Sets the size ensemble.
- *
- * @param sizeEnsemble the new size ensemble
- */
- public void setSizeEnsemble(int sizeEnsemble) {
- this.sizeEnsemble = sizeEnsemble;
- }
-
-
- /* (non-Javadoc)
- * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
- */
- @Override
- public Processor newProcessor(Processor sourceProcessor) {
- BaggingDistributorProcessor newProcessor = new BaggingDistributorProcessor();
- BaggingDistributorProcessor originProcessor = (BaggingDistributorProcessor) sourceProcessor;
- if (originProcessor.getPredictionStream() != null){
- newProcessor.setPredictionStream(originProcessor.getPredictionStream());
- }
- if (originProcessor.getTrainingStream() != null){
- newProcessor.setOutputStream(originProcessor.getTrainingStream());
- }
- newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
- /*if (originProcessor.getLearningCurve() != null){
- newProcessor.setLearningCurve((LearningCurve) originProcessor.getLearningCurve().copy());
- }*/
- return newProcessor;
- }
+ /**
+ * Gets the size ensemble.
+ *
+ * @return the size ensemble
+ */
+ public int getSizeEnsemble() {
+ return sizeEnsemble;
+ }
+
+ /**
+ * Sets the size ensemble.
+ *
+ * @param sizeEnsemble
+ * the new size ensemble
+ */
+ public void setSizeEnsemble(int sizeEnsemble) {
+ this.sizeEnsemble = sizeEnsemble;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
+ */
+ @Override
+ public Processor newProcessor(Processor sourceProcessor) {
+ BaggingDistributorProcessor newProcessor = new BaggingDistributorProcessor();
+ BaggingDistributorProcessor originProcessor = (BaggingDistributorProcessor) sourceProcessor;
+ if (originProcessor.getPredictionStream() != null) {
+ newProcessor.setPredictionStream(originProcessor.getPredictionStream());
+ }
+ if (originProcessor.getTrainingStream() != null) {
+ newProcessor.setOutputStream(originProcessor.getTrainingStream());
+ }
+ newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
+ /*
+ * if (originProcessor.getLearningCurve() != null){
+ * newProcessor.setLearningCurve((LearningCurve)
+ * originProcessor.getLearningCurve().copy()); }
+ */
+ return newProcessor;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Boosting.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Boosting.java
index 06723e2..e81c490 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Boosting.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/Boosting.java
@@ -40,103 +40,108 @@
/**
* The Bagging Classifier by Oza and Russell.
*/
-public class Boosting implements Learner , Configurable {
+public class Boosting implements Learner, Configurable {
- /** The Constant serialVersionUID. */
- private static final long serialVersionUID = -2971850264864952099L;
-
- /** The base learner option. */
- public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
- "Classifier to train.", Learner.class, SingleClassifier.class.getName());
+ /** The Constant serialVersionUID. */
+ private static final long serialVersionUID = -2971850264864952099L;
- /** The ensemble size option. */
- public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
- "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
+ /** The base learner option. */
+ public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
+ "Classifier to train.", Learner.class, SingleClassifier.class.getName());
- /** The distributor processor. */
- private BoostingDistributorProcessor distributorP;
+ /** The ensemble size option. */
+ public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
+ "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
- /** The result stream. */
- protected Stream resultStream;
-
- /** The dataset. */
- private Instances dataset;
-
- protected Learner classifier;
-
- protected int parallelism;
+ /** The distributor processor. */
+ private BoostingDistributorProcessor distributorP;
- /**
- * Sets the layout.
- */
- protected void setLayout() {
+ /** The result stream. */
+ protected Stream resultStream;
- int sizeEnsemble = this.ensembleSizeOption.getValue();
+ /** The dataset. */
+ private Instances dataset;
- distributorP = new BoostingDistributorProcessor();
- distributorP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(distributorP, 1);
-
- //instantiate classifier
- classifier = this.baseLearnerOption.getValue();
- classifier.init(builder, this.dataset, sizeEnsemble);
-
- BoostingPredictionCombinerProcessor predictionCombinerP= new BoostingPredictionCombinerProcessor();
- predictionCombinerP.setSizeEnsemble(sizeEnsemble);
- this.builder.addProcessor(predictionCombinerP, 1);
-
- //Streams
- resultStream = this.builder.createStream(predictionCombinerP);
- predictionCombinerP.setOutputStream(resultStream);
+ protected Learner classifier;
- for (Stream subResultStream:classifier.getResultStreams()) {
- this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
- }
-
- /* The testing stream. */
- Stream testingStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
-
- /* The prediction stream. */
- Stream predictionStream = this.builder.createStream(distributorP);
- this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
-
- distributorP.setOutputStream(testingStream);
- distributorP.setPredictionStream(predictionStream);
-
+ protected int parallelism;
+
+ /**
+ * Sets the layout.
+ */
+ protected void setLayout() {
+
+ int sizeEnsemble = this.ensembleSizeOption.getValue();
+
+ distributorP = new BoostingDistributorProcessor();
+ distributorP.setSizeEnsemble(sizeEnsemble);
+ this.builder.addProcessor(distributorP, 1);
+
+ // instantiate classifier
+ classifier = this.baseLearnerOption.getValue();
+ classifier.init(builder, this.dataset, sizeEnsemble);
+
+ BoostingPredictionCombinerProcessor predictionCombinerP = new BoostingPredictionCombinerProcessor();
+ predictionCombinerP.setSizeEnsemble(sizeEnsemble);
+ this.builder.addProcessor(predictionCombinerP, 1);
+
+ // Streams
+ resultStream = this.builder.createStream(predictionCombinerP);
+ predictionCombinerP.setOutputStream(resultStream);
+
+ for (Stream subResultStream : classifier.getResultStreams()) {
+ this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
+ }
+
+ /* The testing stream. */
+ Stream testingStream = this.builder.createStream(distributorP);
+ this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());
+
+ /* The prediction stream. */
+ Stream predictionStream = this.builder.createStream(distributorP);
+ this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());
+
+ distributorP.setOutputStream(testingStream);
+ distributorP.setPredictionStream(predictionStream);
+
// Addition to Bagging: stream to train
/* The training stream. */
- Stream trainingStream = this.builder.createStream(predictionCombinerP);
- predictionCombinerP.setTrainingStream(trainingStream);
- this.builder.connectInputKeyStream(trainingStream, classifier.getInputProcessor());
-
- }
+ Stream trainingStream = this.builder.createStream(predictionCombinerP);
+ predictionCombinerP.setTrainingStream(trainingStream);
+ this.builder.connectInputKeyStream(trainingStream, classifier.getInputProcessor());
- /** The builder. */
- private TopologyBuilder builder;
+ }
- /* (non-Javadoc)
- * @see samoa.classifiers.Classifier#init(samoa.engines.Engine, samoa.core.Stream, weka.core.Instances)
- */
-
- @Override
- public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
- this.builder = builder;
- this.dataset = dataset;
- this.parallelism = parallelism;
- this.setLayout();
- }
+ /** The builder. */
+ private TopologyBuilder builder;
- @Override
- public Processor getInputProcessor() {
- return distributorP;
- }
-
- /* (non-Javadoc)
- * @see samoa.learners.Learner#getResultStreams()
- */
- @Override
- public Set<Stream> getResultStreams() {
- return ImmutableSet.of(this.resultStream);
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.classifiers.Classifier#init(samoa.engines.Engine,
+ * samoa.core.Stream, weka.core.Instances)
+ */
+
+ @Override
+ public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+ this.builder = builder;
+ this.dataset = dataset;
+ this.parallelism = parallelism;
+ this.setLayout();
+ }
+
+ @Override
+ public Processor getInputProcessor() {
+ return distributorP;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.learners.Learner#getResultStreams()
+ */
+ @Override
+ public Set<Stream> getResultStreams() {
+ return ImmutableSet.of(this.resultStream);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java
index 7100e7e..78c2bd0 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java
@@ -22,15 +22,14 @@
* #L%
*/
-
/**
* The Class BoostingDistributorProcessor.
*/
-public class BoostingDistributorProcessor extends BaggingDistributorProcessor{
-
- @Override
- protected void train(InstanceContentEvent inEvent) {
- // Boosting is trained from the prediction combiner, not from the input
- }
-
+public class BoostingDistributorProcessor extends BaggingDistributorProcessor {
+
+ @Override
+ protected void train(InstanceContentEvent inEvent) {
+ // Boosting is trained from the prediction combiner, not from the input
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java
index 1d8db50..8acbdc8 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java
@@ -40,137 +40,139 @@
*/
public class BoostingPredictionCombinerProcessor extends PredictionCombinerProcessor {
- private static final long serialVersionUID = -1606045723451191232L;
-
- //Weigths classifier
- protected double[] scms;
+ private static final long serialVersionUID = -1606045723451191232L;
- //Weights instance
- protected double[] swms;
+ // Weights classifier
+ protected double[] scms;
- /**
- * On event.
- *
- * @param event the event
- * @return true, if successful
- */
- @Override
- public boolean process(ContentEvent event) {
+ // Weights instance
+ protected double[] swms;
- ResultContentEvent inEvent = (ResultContentEvent) event;
- double[] prediction = inEvent.getClassVotes();
- int instanceIndex = (int) inEvent.getInstanceIndex();
-
- addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1);
- //Boosting
- addPredictions(instanceIndex, inEvent, prediction);
-
- if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) {
- DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
- if (combinedVote == null){
- combinedVote = new DoubleVector();
- }
- ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
- inEvent.getInstance(), inEvent.getClassId(),
- combinedVote.getArrayCopy(), inEvent.isLastEvent());
- outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- outputStream.put(outContentEvent);
- clearStatisticsInstance(instanceIndex);
- //Boosting
- computeBoosting(inEvent, instanceIndex);
- return true;
- }
- return false;
+ /**
+ * On event.
+ *
+ * @param event
+ * the event
+ * @return true, if successful
+ */
+ @Override
+ public boolean process(ContentEvent event) {
+ ResultContentEvent inEvent = (ResultContentEvent) event;
+ double[] prediction = inEvent.getClassVotes();
+ int instanceIndex = (int) inEvent.getInstanceIndex();
+
+ addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1);
+ // Boosting
+ addPredictions(instanceIndex, inEvent, prediction);
+
+ if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) {
+ DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
+ if (combinedVote == null) {
+ combinedVote = new DoubleVector();
+ }
+ ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
+ inEvent.getInstance(), inEvent.getClassId(),
+ combinedVote.getArrayCopy(), inEvent.isLastEvent());
+ outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+ outputStream.put(outContentEvent);
+ clearStatisticsInstance(instanceIndex);
+ // Boosting
+ computeBoosting(inEvent, instanceIndex);
+ return true;
}
-
- protected Random random;
-
- protected int trainingWeightSeenByModel;
-
- @Override
- protected double getEnsembleMemberWeight(int i) {
- double em = this.swms[i] / (this.scms[i] + this.swms[i]);
- if ((em == 0.0) || (em > 0.5)) {
- return 0.0;
- }
- double Bm = em / (1.0 - em);
- return Math.log(1.0 / Bm);
- }
-
- @Override
- public void reset() {
- this.random = new Random();
- this.trainingWeightSeenByModel = 0;
- this.scms = new double[this.ensembleSize];
- this.swms = new double[this.ensembleSize];
- }
+ return false;
- private boolean correctlyClassifies(int i, Instance inst, int instanceIndex) {
- int predictedClass = (int) mapPredictions.get(instanceIndex).getValue(i);
- return predictedClass == (int) inst.classValue();
- }
-
- protected Map<Integer, DoubleVector> mapPredictions;
+ }
- private void addPredictions(int instanceIndex, ResultContentEvent inEvent, double[] prediction) {
- if (this.mapPredictions == null) {
- this.mapPredictions = new HashMap<>();
- }
- DoubleVector predictions = this.mapPredictions.get(instanceIndex);
- if (predictions == null){
- predictions = new DoubleVector();
- }
- predictions.setValue(inEvent.getClassifierIndex(), Utils.maxIndex(prediction));
- this.mapPredictions.put(instanceIndex, predictions);
- }
+ protected Random random;
- private void computeBoosting(ResultContentEvent inEvent, int instanceIndex) {
- // Starts code for Boosting
- //Send instances to train
- double lambda_d = 1.0;
- for (int i = 0; i < this.ensembleSize; i++) {
- double k = lambda_d;
- Instance inst = inEvent.getInstance();
- if (k > 0.0) {
- Instance weightedInst = inst.copy();
- weightedInst.setWeight(inst.weight() * k);
- //this.ensemble[i].trainOnInstance(weightedInst);
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
- inEvent.getInstanceIndex(), weightedInst, true, false);
- instanceContentEvent.setClassifierIndex(i);
- instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- trainingStream.put(instanceContentEvent);
- }
- if (this.correctlyClassifies(i, inst, instanceIndex)){
- this.scms[i] += lambda_d;
- lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]);
- } else {
- this.swms[i] += lambda_d;
- lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]);
- }
- }
- }
-
- /**
- * Gets the training stream.
- *
- * @return the training stream
- */
- public Stream getTrainingStream() {
- return trainingStream;
- }
+ protected int trainingWeightSeenByModel;
- /**
- * Sets the training stream.
- *
- * @param trainingStream the new training stream
- */
- public void setTrainingStream(Stream trainingStream) {
- this.trainingStream = trainingStream;
- }
-
- /** The training stream. */
- private Stream trainingStream;
-
+ @Override
+ protected double getEnsembleMemberWeight(int i) {
+ double em = this.swms[i] / (this.scms[i] + this.swms[i]);
+ if ((em == 0.0) || (em > 0.5)) {
+ return 0.0;
+ }
+ double Bm = em / (1.0 - em);
+ return Math.log(1.0 / Bm);
+ }
+
+ @Override
+ public void reset() {
+ this.random = new Random();
+ this.trainingWeightSeenByModel = 0;
+ this.scms = new double[this.ensembleSize];
+ this.swms = new double[this.ensembleSize];
+ }
+
+ private boolean correctlyClassifies(int i, Instance inst, int instanceIndex) {
+ int predictedClass = (int) mapPredictions.get(instanceIndex).getValue(i);
+ return predictedClass == (int) inst.classValue();
+ }
+
+ protected Map<Integer, DoubleVector> mapPredictions;
+
+ private void addPredictions(int instanceIndex, ResultContentEvent inEvent, double[] prediction) {
+ if (this.mapPredictions == null) {
+ this.mapPredictions = new HashMap<>();
+ }
+ DoubleVector predictions = this.mapPredictions.get(instanceIndex);
+ if (predictions == null) {
+ predictions = new DoubleVector();
+ }
+ predictions.setValue(inEvent.getClassifierIndex(), Utils.maxIndex(prediction));
+ this.mapPredictions.put(instanceIndex, predictions);
+ }
+
+ private void computeBoosting(ResultContentEvent inEvent, int instanceIndex) {
+ // Starts code for Boosting
+ // Send instances to train
+ double lambda_d = 1.0;
+ for (int i = 0; i < this.ensembleSize; i++) {
+ double k = lambda_d;
+ Instance inst = inEvent.getInstance();
+ if (k > 0.0) {
+ Instance weightedInst = inst.copy();
+ weightedInst.setWeight(inst.weight() * k);
+ // this.ensemble[i].trainOnInstance(weightedInst);
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
+ inEvent.getInstanceIndex(), weightedInst, true, false);
+ instanceContentEvent.setClassifierIndex(i);
+ instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+ trainingStream.put(instanceContentEvent);
+ }
+ if (this.correctlyClassifies(i, inst, instanceIndex)) {
+ this.scms[i] += lambda_d;
+ lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]);
+ } else {
+ this.swms[i] += lambda_d;
+ lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]);
+ }
+ }
+ }
+
+ /**
+ * Gets the training stream.
+ *
+ * @return the training stream
+ */
+ public Stream getTrainingStream() {
+ return trainingStream;
+ }
+
+ /**
+ * Sets the training stream.
+ *
+ * @param trainingStream
+ * the new training stream
+ */
+ public void setTrainingStream(Stream trainingStream) {
+ this.trainingStream = trainingStream;
+ }
+
+ /** The training stream. */
+ private Stream trainingStream;
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
index e4228d8..fff801f 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java
@@ -37,148 +37,151 @@
*/
public class PredictionCombinerProcessor implements Processor {
- private static final long serialVersionUID = -1606045723451191132L;
+ private static final long serialVersionUID = -1606045723451191132L;
- /**
- * The size ensemble.
- */
- protected int ensembleSize;
+ /**
+ * The size ensemble.
+ */
+ protected int ensembleSize;
- /**
- * The output stream.
- */
- protected Stream outputStream;
+ /**
+ * The output stream.
+ */
+ protected Stream outputStream;
- /**
- * Sets the output stream.
- *
- * @param stream the new output stream
- */
- public void setOutputStream(Stream stream) {
- outputStream = stream;
+ /**
+ * Sets the output stream.
+ *
+ * @param stream
+ * the new output stream
+ */
+ public void setOutputStream(Stream stream) {
+ outputStream = stream;
+ }
+
+ /**
+ * Gets the output stream.
+ *
+ * @return the output stream
+ */
+ public Stream getOutputStream() {
+ return outputStream;
+ }
+
+ /**
+ * Gets the size ensemble.
+ *
+ * @return the ensembleSize
+ */
+ public int getSizeEnsemble() {
+ return ensembleSize;
+ }
+
+ /**
+ * Sets the size ensemble.
+ *
+ * @param ensembleSize
+ * the new size ensemble
+ */
+ public void setSizeEnsemble(int ensembleSize) {
+ this.ensembleSize = ensembleSize;
+ }
+
+ protected Map<Integer, Integer> mapCountsforInstanceReceived;
+
+ protected Map<Integer, DoubleVector> mapVotesforInstanceReceived;
+
+ /**
+ * On event.
+ *
+ * @param event
+ * the event
+ * @return true, if successful
+ */
+ public boolean process(ContentEvent event) {
+
+ ResultContentEvent inEvent = (ResultContentEvent) event;
+ double[] prediction = inEvent.getClassVotes();
+ int instanceIndex = (int) inEvent.getInstanceIndex();
+
+ addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1);
+
+ if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) {
+ DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
+ if (combinedVote == null) {
+ combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]);
+ }
+ ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
+ inEvent.getInstance(), inEvent.getClassId(),
+ combinedVote.getArrayCopy(), inEvent.isLastEvent());
+ outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+ outputStream.put(outContentEvent);
+ clearStatisticsInstance(instanceIndex);
+ return true;
}
+ return false;
- /**
- * Gets the output stream.
- *
- * @return the output stream
- */
- public Stream getOutputStream() {
- return outputStream;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.reset();
+ }
+
+ public void reset() {
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
+ */
+ @Override
+ public Processor newProcessor(Processor sourceProcessor) {
+ PredictionCombinerProcessor newProcessor = new PredictionCombinerProcessor();
+ PredictionCombinerProcessor originProcessor = (PredictionCombinerProcessor) sourceProcessor;
+ if (originProcessor.getOutputStream() != null) {
+ newProcessor.setOutputStream(originProcessor.getOutputStream());
}
+ newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
+ return newProcessor;
+ }
- /**
- * Gets the size ensemble.
- *
- * @return the ensembleSize
- */
- public int getSizeEnsemble() {
- return ensembleSize;
+ protected void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex, double[] prediction, int add) {
+ if (this.mapCountsforInstanceReceived == null) {
+ this.mapCountsforInstanceReceived = new HashMap<>();
+ this.mapVotesforInstanceReceived = new HashMap<>();
}
+ DoubleVector vote = new DoubleVector(prediction);
+ if (vote.sumOfValues() > 0.0) {
+ vote.normalize();
+ DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
+ if (combinedVote == null) {
+ combinedVote = new DoubleVector();
+ }
+ vote.scaleValues(getEnsembleMemberWeight(classifierIndex));
+ combinedVote.addValues(vote);
- /**
- * Sets the size ensemble.
- *
- * @param ensembleSize the new size ensemble
- */
- public void setSizeEnsemble(int ensembleSize) {
- this.ensembleSize = ensembleSize;
+ this.mapVotesforInstanceReceived.put(instanceIndex, combinedVote);
}
-
- protected Map<Integer, Integer> mapCountsforInstanceReceived;
-
- protected Map<Integer, DoubleVector> mapVotesforInstanceReceived;
-
- /**
- * On event.
- *
- * @param event the event
- * @return true, if successful
- */
- public boolean process(ContentEvent event) {
-
- ResultContentEvent inEvent = (ResultContentEvent) event;
- double[] prediction = inEvent.getClassVotes();
- int instanceIndex = (int) inEvent.getInstanceIndex();
-
- addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1);
-
- if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) {
- DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
- if (combinedVote == null){
- combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]);
- }
- ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
- inEvent.getInstance(), inEvent.getClassId(),
- combinedVote.getArrayCopy(), inEvent.isLastEvent());
- outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
- outputStream.put(outContentEvent);
- clearStatisticsInstance(instanceIndex);
- return true;
- }
- return false;
-
+ Integer count = this.mapCountsforInstanceReceived.get(instanceIndex);
+ if (count == null) {
+ count = 0;
}
+ this.mapCountsforInstanceReceived.put(instanceIndex, count + add);
+ }
- @Override
- public void onCreate(int id) {
- this.reset();
- }
+ protected boolean hasAllVotesArrivedInstance(int instanceIndex) {
+ return (this.mapCountsforInstanceReceived.get(instanceIndex) == this.ensembleSize);
+ }
- public void reset() {
- }
+ protected void clearStatisticsInstance(int instanceIndex) {
+ this.mapCountsforInstanceReceived.remove(instanceIndex);
+ this.mapVotesforInstanceReceived.remove(instanceIndex);
+ }
+ protected double getEnsembleMemberWeight(int i) {
+ return 1.0;
+ }
- /* (non-Javadoc)
- * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
- */
- @Override
- public Processor newProcessor(Processor sourceProcessor) {
- PredictionCombinerProcessor newProcessor = new PredictionCombinerProcessor();
- PredictionCombinerProcessor originProcessor = (PredictionCombinerProcessor) sourceProcessor;
- if (originProcessor.getOutputStream() != null) {
- newProcessor.setOutputStream(originProcessor.getOutputStream());
- }
- newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
- return newProcessor;
- }
-
- protected void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex, double[] prediction, int add) {
- if (this.mapCountsforInstanceReceived == null) {
- this.mapCountsforInstanceReceived = new HashMap<>();
- this.mapVotesforInstanceReceived = new HashMap<>();
- }
- DoubleVector vote = new DoubleVector(prediction);
- if (vote.sumOfValues() > 0.0) {
- vote.normalize();
- DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
- if (combinedVote == null){
- combinedVote = new DoubleVector();
- }
- vote.scaleValues(getEnsembleMemberWeight(classifierIndex));
- combinedVote.addValues(vote);
-
- this.mapVotesforInstanceReceived.put(instanceIndex, combinedVote);
- }
- Integer count = this.mapCountsforInstanceReceived.get(instanceIndex);
- if (count == null) {
- count = 0;
- }
- this.mapCountsforInstanceReceived.put(instanceIndex, count + add);
- }
-
- protected boolean hasAllVotesArrivedInstance(int instanceIndex) {
- return (this.mapCountsforInstanceReceived.get(instanceIndex) == this.ensembleSize);
- }
-
- protected void clearStatisticsInstance(int instanceIndex) {
- this.mapCountsforInstanceReceived.remove(instanceIndex);
- this.mapVotesforInstanceReceived.remove(instanceIndex);
- }
-
- protected double getEnsembleMemberWeight(int i) {
- return 1.0;
- }
-
-
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/AMRulesRegressor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/AMRulesRegressor.java
index 268072b..a05772b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/AMRulesRegressor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/AMRulesRegressor.java
@@ -39,138 +39,139 @@
import com.yahoo.labs.samoa.topology.TopologyBuilder;
/**
- * AMRules Regressor
- * is the task for the serialized implementation of AMRules algorithm for regression rule.
- * It is adapted to SAMOA from the implementation of AMRules in MOA.
+ * AMRules Regressor is the task for the serialized implementation of AMRules
+ * algorithm for regression rule. It is adapted to SAMOA from the implementation
+ * of AMRules in MOA.
*
* @author Anh Thu Vu
- *
+ *
*/
public class AMRulesRegressor implements RegressionLearner, Configurable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- // Options
- public FloatOption splitConfidenceOption = new FloatOption(
- "splitConfidence",
- 'c',
- "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
- 0.0000001, 0.0, 1.0);
+ // Options
+ public FloatOption splitConfidenceOption = new FloatOption(
+ "splitConfidence",
+ 'c',
+ "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
+ 0.0000001, 0.0, 1.0);
- public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
- 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
- 0.05, 0.0, 1.0);
+ public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
+ 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
+ 0.05, 0.0, 1.0);
- public IntOption gracePeriodOption = new IntOption("gracePeriod",
- 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
- 200, 1, Integer.MAX_VALUE);
+ public IntOption gracePeriodOption = new IntOption("gracePeriod",
+ 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
+ 200, 1, Integer.MAX_VALUE);
- public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
- "Drift Detection. Page-Hinkley.");
+ public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
+ "Drift Detection. Page-Hinkley.");
- public FloatOption pageHinckleyAlphaOption = new FloatOption(
- "pageHinckleyAlpha",
- 'a',
- "The alpha value to use in the Page Hinckley change detection tests.",
- 0.005, 0.0, 1.0);
+ public FloatOption pageHinckleyAlphaOption = new FloatOption(
+ "pageHinckleyAlpha",
+ 'a',
+ "The alpha value to use in the Page Hinckley change detection tests.",
+ 0.005, 0.0, 1.0);
- public IntOption pageHinckleyThresholdOption = new IntOption(
- "pageHinckleyThreshold",
- 'l',
- "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
- 35, 0, Integer.MAX_VALUE);
+ public IntOption pageHinckleyThresholdOption = new IntOption(
+ "pageHinckleyThreshold",
+ 'l',
+ "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
+ 35, 0, Integer.MAX_VALUE);
- public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
- "Disable anomaly Detection.");
+ public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
+ "Disable anomaly Detection.");
- public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
- "multivariateAnomalyProbabilityThresholdd",
- 'm',
- "Multivariate anomaly threshold value.",
- 0.99, 0.0, 1.0);
+ public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
+ "multivariateAnomalyProbabilityThresholdd",
+ 'm',
+ "Multivariate anomaly threshold value.",
+ 0.99, 0.0, 1.0);
- public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
- "univariateAnomalyprobabilityThreshold",
- 'u',
- "Univariate anomaly threshold value.",
- 0.10, 0.0, 1.0);
+ public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
+ "univariateAnomalyprobabilityThreshold",
+ 'u',
+ "Univariate anomaly threshold value.",
+ 0.10, 0.0, 1.0);
- public IntOption anomalyNumInstThresholdOption = new IntOption(
- "anomalyThreshold",
- 'n',
- "The threshold value of anomalies to be used in the anomaly detection.",
- 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15.
+ public IntOption anomalyNumInstThresholdOption = new IntOption(
+ "anomalyThreshold",
+ 'n',
+ "The threshold value of anomalies to be used in the anomaly detection.",
+ 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect
+ // anomalies. 15.
- public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
- "unorderedRules.");
-
- public ClassOption numericObserverOption = new ClassOption("numericObserver",
- 'z', "Numeric observer.",
- FIMTDDNumericAttributeClassLimitObserver.class,
- "FIMTDDNumericAttributeClassLimitObserver");
+ public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
+ "unorderedRules.");
- public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
- "predictionFunctionOption", 'P', "The prediction function to use.", new String[]{
- "Adaptative","Perceptron", "Target Mean"}, new String[]{
- "Adaptative","Perceptron", "Target Mean"}, 0);
+ public ClassOption numericObserverOption = new ClassOption("numericObserver",
+ 'z', "Numeric observer.",
+ FIMTDDNumericAttributeClassLimitObserver.class,
+ "FIMTDDNumericAttributeClassLimitObserver");
- public FlagOption constantLearningRatioDecayOption = new FlagOption(
- "learningRatio_Decay_set_constant", 'd',
- "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
+ public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
+ "predictionFunctionOption", 'P', "The prediction function to use.", new String[] {
+ "Adaptative", "Perceptron", "Target Mean" }, new String[] {
+ "Adaptative", "Perceptron", "Target Mean" }, 0);
- public FloatOption learningRatioOption = new FloatOption(
- "learningRatio", 's',
- "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);
-
- public ClassOption votingTypeOption = new ClassOption("votingType",
- 'V', "Voting Type.",
- ErrorWeightedVote.class,
- "InverseErrorWeightedVote");
-
- // Processor
- private AMRulesRegressorProcessor processor;
-
- // Stream
- private Stream resultStream;
-
- @Override
- public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
- this.processor = new AMRulesRegressorProcessor.Builder(dataset)
- .threshold(pageHinckleyThresholdOption.getValue())
- .alpha(pageHinckleyAlphaOption.getValue())
- .changeDetection(this.DriftDetectionOption.isSet())
- .predictionFunction(predictionFunctionOption.getChosenIndex())
- .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
- .learningRatio(learningRatioOption.getValue())
- .splitConfidence(splitConfidenceOption.getValue())
- .tieThreshold(tieThresholdOption.getValue())
- .gracePeriod(gracePeriodOption.getValue())
- .noAnomalyDetection(noAnomalyDetectionOption.isSet())
- .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
- .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
- .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
- .unorderedRules(unorderedRulesOption.isSet())
- .numericObserver((FIMTDDNumericAttributeClassLimitObserver)numericObserverOption.getValue())
- .voteType((ErrorWeightedVote)votingTypeOption.getValue())
- .build();
-
- topologyBuilder.addProcessor(processor, parallelism);
-
- this.resultStream = topologyBuilder.createStream(processor);
- this.processor.setResultStream(resultStream);
- }
-
- @Override
- public Processor getInputProcessor() {
- return processor;
- }
+ public FlagOption constantLearningRatioDecayOption = new FlagOption(
+ "learningRatio_Decay_set_constant", 'd',
+ "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
- @Override
- public Set<Stream> getResultStreams() {
- return ImmutableSet.of(this.resultStream);
- }
+ public FloatOption learningRatioOption = new FloatOption(
+ "learningRatio", 's',
+ "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);
+
+ public ClassOption votingTypeOption = new ClassOption("votingType",
+ 'V', "Voting Type.",
+ ErrorWeightedVote.class,
+ "InverseErrorWeightedVote");
+
+ // Processor
+ private AMRulesRegressorProcessor processor;
+
+ // Stream
+ private Stream resultStream;
+
+ @Override
+ public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
+ this.processor = new AMRulesRegressorProcessor.Builder(dataset)
+ .threshold(pageHinckleyThresholdOption.getValue())
+ .alpha(pageHinckleyAlphaOption.getValue())
+ .changeDetection(this.DriftDetectionOption.isSet())
+ .predictionFunction(predictionFunctionOption.getChosenIndex())
+ .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
+ .learningRatio(learningRatioOption.getValue())
+ .splitConfidence(splitConfidenceOption.getValue())
+ .tieThreshold(tieThresholdOption.getValue())
+ .gracePeriod(gracePeriodOption.getValue())
+ .noAnomalyDetection(noAnomalyDetectionOption.isSet())
+ .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
+ .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
+ .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
+ .unorderedRules(unorderedRulesOption.isSet())
+ .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
+ .voteType((ErrorWeightedVote) votingTypeOption.getValue())
+ .build();
+
+ topologyBuilder.addProcessor(processor, parallelism);
+
+ this.resultStream = topologyBuilder.createStream(processor);
+ this.processor.setResultStream(resultStream);
+ }
+
+ @Override
+ public Processor getInputProcessor() {
+ return processor;
+ }
+
+ @Override
+ public Set<Stream> getResultStreams() {
+ return ImmutableSet.of(this.resultStream);
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java
index 14f5f38..9d4c5e6 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java
@@ -40,201 +40,201 @@
import com.yahoo.labs.samoa.topology.TopologyBuilder;
/**
- * Horizontal AMRules Regressor
- * is a distributed learner for regression rules learner.
- * It applies both horizontal parallelism (dividing incoming streams)
+ * Horizontal AMRules Regressor is a distributed learner for regression rules
+ * learner. It applies both horizontal parallelism (dividing incoming streams)
* and vertical parallelism on AMRules algorithm.
*
* @author Anh Thu Vu
- *
+ *
*/
public class HorizontalAMRulesRegressor implements RegressionLearner, Configurable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 2785944439173586051L;
+ private static final long serialVersionUID = 2785944439173586051L;
- // Options
- public FloatOption splitConfidenceOption = new FloatOption(
- "splitConfidence",
- 'c',
- "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
- 0.0000001, 0.0, 1.0);
+ // Options
+ public FloatOption splitConfidenceOption = new FloatOption(
+ "splitConfidence",
+ 'c',
+ "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
+ 0.0000001, 0.0, 1.0);
- public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
- 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
- 0.05, 0.0, 1.0);
+ public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
+ 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
+ 0.05, 0.0, 1.0);
- public IntOption gracePeriodOption = new IntOption("gracePeriod",
- 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
- 200, 1, Integer.MAX_VALUE);
+ public IntOption gracePeriodOption = new IntOption("gracePeriod",
+ 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
+ 200, 1, Integer.MAX_VALUE);
- public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
- "Drift Detection. Page-Hinkley.");
+ public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
+ "Drift Detection. Page-Hinkley.");
- public FloatOption pageHinckleyAlphaOption = new FloatOption(
- "pageHinckleyAlpha",
- 'a',
- "The alpha value to use in the Page Hinckley change detection tests.",
- 0.005, 0.0, 1.0);
+ public FloatOption pageHinckleyAlphaOption = new FloatOption(
+ "pageHinckleyAlpha",
+ 'a',
+ "The alpha value to use in the Page Hinckley change detection tests.",
+ 0.005, 0.0, 1.0);
- public IntOption pageHinckleyThresholdOption = new IntOption(
- "pageHinckleyThreshold",
- 'l',
- "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
- 35, 0, Integer.MAX_VALUE);
+ public IntOption pageHinckleyThresholdOption = new IntOption(
+ "pageHinckleyThreshold",
+ 'l',
+ "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
+ 35, 0, Integer.MAX_VALUE);
- public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
- "Disable anomaly Detection.");
+ public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
+ "Disable anomaly Detection.");
- public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
- "multivariateAnomalyProbabilityThresholdd",
- 'm',
- "Multivariate anomaly threshold value.",
- 0.99, 0.0, 1.0);
+ public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
+ "multivariateAnomalyProbabilityThresholdd",
+ 'm',
+ "Multivariate anomaly threshold value.",
+ 0.99, 0.0, 1.0);
- public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
- "univariateAnomalyprobabilityThreshold",
- 'u',
- "Univariate anomaly threshold value.",
- 0.10, 0.0, 1.0);
+ public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
+ "univariateAnomalyprobabilityThreshold",
+ 'u',
+ "Univariate anomaly threshold value.",
+ 0.10, 0.0, 1.0);
- public IntOption anomalyNumInstThresholdOption = new IntOption(
- "anomalyThreshold",
- 'n',
- "The threshold value of anomalies to be used in the anomaly detection.",
- 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15.
+ public IntOption anomalyNumInstThresholdOption = new IntOption(
+ "anomalyThreshold",
+ 'n',
+ "The threshold value of anomalies to be used in the anomaly detection.",
+ 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect
+ // anomalies. 15.
- public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
- "unorderedRules.");
+ public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
+ "unorderedRules.");
- public ClassOption numericObserverOption = new ClassOption("numericObserver",
- 'z', "Numeric observer.",
- FIMTDDNumericAttributeClassLimitObserver.class,
- "FIMTDDNumericAttributeClassLimitObserver");
+ public ClassOption numericObserverOption = new ClassOption("numericObserver",
+ 'z', "Numeric observer.",
+ FIMTDDNumericAttributeClassLimitObserver.class,
+ "FIMTDDNumericAttributeClassLimitObserver");
- public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
- "predictionFunctionOption", 'P', "The prediction function to use.", new String[]{
- "Adaptative","Perceptron", "Target Mean"}, new String[]{
- "Adaptative","Perceptron", "Target Mean"}, 0);
+ public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
+ "predictionFunctionOption", 'P', "The prediction function to use.", new String[] {
+ "Adaptative", "Perceptron", "Target Mean" }, new String[] {
+ "Adaptative", "Perceptron", "Target Mean" }, 0);
- public FlagOption constantLearningRatioDecayOption = new FlagOption(
- "learningRatio_Decay_set_constant", 'd',
- "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
+ public FlagOption constantLearningRatioDecayOption = new FlagOption(
+ "learningRatio_Decay_set_constant", 'd',
+ "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
- public FloatOption learningRatioOption = new FloatOption(
- "learningRatio", 's',
- "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);
+ public FloatOption learningRatioOption = new FloatOption(
+ "learningRatio", 's',
+ "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);
- public MultiChoiceOption votingTypeOption = new MultiChoiceOption(
- "votingType", 'V', "Voting Type.", new String[]{
- "InverseErrorWeightedVote","UniformWeightedVote"}, new String[]{
- "InverseErrorWeightedVote","UniformWeightedVote"}, 0);
-
- public IntOption learnerParallelismOption = new IntOption(
- "leanerParallelism",
- 'p',
- "The number of local statistics PI to do distributed computation",
- 1, 1, Integer.MAX_VALUE);
- public IntOption ruleSetParallelismOption = new IntOption(
- "modelParallelism",
- 'r',
- "The number of replicated model (rule set) PIs",
- 1, 1, Integer.MAX_VALUE);
+ public MultiChoiceOption votingTypeOption = new MultiChoiceOption(
+ "votingType", 'V', "Voting Type.", new String[] {
+ "InverseErrorWeightedVote", "UniformWeightedVote" }, new String[] {
+ "InverseErrorWeightedVote", "UniformWeightedVote" }, 0);
- // Processor
- private AMRRuleSetProcessor model;
+ public IntOption learnerParallelismOption = new IntOption(
+ "leanerParallelism",
+ 'p',
+ "The number of local statistics PI to do distributed computation",
+ 1, 1, Integer.MAX_VALUE);
+ public IntOption ruleSetParallelismOption = new IntOption(
+ "modelParallelism",
+ 'r',
+ "The number of replicated model (rule set) PIs",
+ 1, 1, Integer.MAX_VALUE);
- private Stream modelResultStream;
+ // Processor
+ private AMRRuleSetProcessor model;
- private Stream rootResultStream;
+ private Stream modelResultStream;
- // private Stream resultStream;
+ private Stream rootResultStream;
- @Override
- public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
-
- // Create MODEL PIs
- this.model = new AMRRuleSetProcessor.Builder(dataset)
- .noAnomalyDetection(noAnomalyDetectionOption.isSet())
- .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
- .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
- .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
- .unorderedRules(unorderedRulesOption.isSet())
- .voteType(votingTypeOption.getChosenIndex())
- .build();
+ // private Stream resultStream;
- topologyBuilder.addProcessor(model, this.ruleSetParallelismOption.getValue());
-
- // MODEL PIs streams
- Stream forwardToRootStream = topologyBuilder.createStream(this.model);
- Stream forwardToLearnerStream = topologyBuilder.createStream(this.model);
- this.modelResultStream = topologyBuilder.createStream(this.model);
-
- this.model.setDefaultRuleStream(forwardToRootStream);
- this.model.setStatisticsStream(forwardToLearnerStream);
- this.model.setResultStream(this.modelResultStream);
-
- // Create DefaultRule PI
- AMRDefaultRuleProcessor root = new AMRDefaultRuleProcessor.Builder(dataset)
- .threshold(pageHinckleyThresholdOption.getValue())
- .alpha(pageHinckleyAlphaOption.getValue())
- .changeDetection(this.DriftDetectionOption.isSet())
- .predictionFunction(predictionFunctionOption.getChosenIndex())
- .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
- .learningRatio(learningRatioOption.getValue())
- .splitConfidence(splitConfidenceOption.getValue())
- .tieThreshold(tieThresholdOption.getValue())
- .gracePeriod(gracePeriodOption.getValue())
- .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
- .build();
-
- topologyBuilder.addProcessor(root);
-
- // Default Rule PI streams
- Stream newRuleStream = topologyBuilder.createStream(root);
- this.rootResultStream = topologyBuilder.createStream(root);
-
- root.setRuleStream(newRuleStream);
- root.setResultStream(this.rootResultStream);
-
- // Create Learner PIs
- AMRLearnerProcessor learner = new AMRLearnerProcessor.Builder(dataset)
- .splitConfidence(splitConfidenceOption.getValue())
- .tieThreshold(tieThresholdOption.getValue())
- .gracePeriod(gracePeriodOption.getValue())
- .noAnomalyDetection(noAnomalyDetectionOption.isSet())
- .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
- .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
- .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
- .build();
-
- topologyBuilder.addProcessor(learner, this.learnerParallelismOption.getValue());
+ @Override
+ public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
- Stream predicateStream = topologyBuilder.createStream(learner);
- learner.setOutputStream(predicateStream);
-
- // Connect streams
- // to MODEL
- topologyBuilder.connectInputAllStream(newRuleStream, this.model);
- topologyBuilder.connectInputAllStream(predicateStream, this.model);
- // to ROOT
- topologyBuilder.connectInputShuffleStream(forwardToRootStream, root);
- // to LEARNER
- topologyBuilder.connectInputKeyStream(forwardToLearnerStream, learner);
- topologyBuilder.connectInputAllStream(newRuleStream, learner);
- }
+ // Create MODEL PIs
+ this.model = new AMRRuleSetProcessor.Builder(dataset)
+ .noAnomalyDetection(noAnomalyDetectionOption.isSet())
+ .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
+ .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
+ .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
+ .unorderedRules(unorderedRulesOption.isSet())
+ .voteType(votingTypeOption.getChosenIndex())
+ .build();
- @Override
- public Processor getInputProcessor() {
- return model;
- }
-
- @Override
- public Set<Stream> getResultStreams() {
- Set<Stream> streams = ImmutableSet.of(this.modelResultStream,this.rootResultStream);
- return streams;
- }
+ topologyBuilder.addProcessor(model, this.ruleSetParallelismOption.getValue());
+
+ // MODEL PIs streams
+ Stream forwardToRootStream = topologyBuilder.createStream(this.model);
+ Stream forwardToLearnerStream = topologyBuilder.createStream(this.model);
+ this.modelResultStream = topologyBuilder.createStream(this.model);
+
+ this.model.setDefaultRuleStream(forwardToRootStream);
+ this.model.setStatisticsStream(forwardToLearnerStream);
+ this.model.setResultStream(this.modelResultStream);
+
+ // Create DefaultRule PI
+ AMRDefaultRuleProcessor root = new AMRDefaultRuleProcessor.Builder(dataset)
+ .threshold(pageHinckleyThresholdOption.getValue())
+ .alpha(pageHinckleyAlphaOption.getValue())
+ .changeDetection(this.DriftDetectionOption.isSet())
+ .predictionFunction(predictionFunctionOption.getChosenIndex())
+ .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
+ .learningRatio(learningRatioOption.getValue())
+ .splitConfidence(splitConfidenceOption.getValue())
+ .tieThreshold(tieThresholdOption.getValue())
+ .gracePeriod(gracePeriodOption.getValue())
+ .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
+ .build();
+
+ topologyBuilder.addProcessor(root);
+
+ // Default Rule PI streams
+ Stream newRuleStream = topologyBuilder.createStream(root);
+ this.rootResultStream = topologyBuilder.createStream(root);
+
+ root.setRuleStream(newRuleStream);
+ root.setResultStream(this.rootResultStream);
+
+ // Create Learner PIs
+ AMRLearnerProcessor learner = new AMRLearnerProcessor.Builder(dataset)
+ .splitConfidence(splitConfidenceOption.getValue())
+ .tieThreshold(tieThresholdOption.getValue())
+ .gracePeriod(gracePeriodOption.getValue())
+ .noAnomalyDetection(noAnomalyDetectionOption.isSet())
+ .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
+ .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
+ .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
+ .build();
+
+ topologyBuilder.addProcessor(learner, this.learnerParallelismOption.getValue());
+
+ Stream predicateStream = topologyBuilder.createStream(learner);
+ learner.setOutputStream(predicateStream);
+
+ // Connect streams
+ // to MODEL
+ topologyBuilder.connectInputAllStream(newRuleStream, this.model);
+ topologyBuilder.connectInputAllStream(predicateStream, this.model);
+ // to ROOT
+ topologyBuilder.connectInputShuffleStream(forwardToRootStream, root);
+ // to LEARNER
+ topologyBuilder.connectInputKeyStream(forwardToLearnerStream, learner);
+ topologyBuilder.connectInputAllStream(newRuleStream, learner);
+ }
+
+ @Override
+ public Processor getInputProcessor() {
+ return model;
+ }
+
+ @Override
+ public Set<Stream> getResultStreams() {
+ Set<Stream> streams = ImmutableSet.of(this.modelResultStream, this.rootResultStream);
+ return streams;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java
index 597becb..131f086 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java
@@ -38,163 +38,163 @@
import com.yahoo.labs.samoa.topology.TopologyBuilder;
/**
- * Vertical AMRules Regressor
- * is a distributed learner for regression rules learner.
- * It applies vertical parallelism on AMRules regressor.
+ * Vertical AMRules Regressor is a distributed learner for regression rules
+ * learner. It applies vertical parallelism on AMRules regressor.
*
* @author Anh Thu Vu
- *
+ *
*/
public class VerticalAMRulesRegressor implements RegressionLearner, Configurable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 2785944439173586051L;
+ private static final long serialVersionUID = 2785944439173586051L;
- // Options
- public FloatOption splitConfidenceOption = new FloatOption(
- "splitConfidence",
- 'c',
- "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
- 0.0000001, 0.0, 1.0);
+ // Options
+ public FloatOption splitConfidenceOption = new FloatOption(
+ "splitConfidence",
+ 'c',
+ "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
+ 0.0000001, 0.0, 1.0);
- public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
- 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
- 0.05, 0.0, 1.0);
+ public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
+ 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
+ 0.05, 0.0, 1.0);
- public IntOption gracePeriodOption = new IntOption("gracePeriod",
- 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
- 200, 1, Integer.MAX_VALUE);
+ public IntOption gracePeriodOption = new IntOption("gracePeriod",
+ 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
+ 200, 1, Integer.MAX_VALUE);
- public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
- "Drift Detection. Page-Hinkley.");
+ public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
+ "Drift Detection. Page-Hinkley.");
- public FloatOption pageHinckleyAlphaOption = new FloatOption(
- "pageHinckleyAlpha",
- 'a',
- "The alpha value to use in the Page Hinckley change detection tests.",
- 00.005, 0.0, 1.0);
+ public FloatOption pageHinckleyAlphaOption = new FloatOption(
+ "pageHinckleyAlpha",
+ 'a',
+ "The alpha value to use in the Page Hinckley change detection tests.",
+ 00.005, 0.0, 1.0);
- public IntOption pageHinckleyThresholdOption = new IntOption(
- "pageHinckleyThreshold",
- 'l',
- "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
- 35, 0, Integer.MAX_VALUE);
+ public IntOption pageHinckleyThresholdOption = new IntOption(
+ "pageHinckleyThreshold",
+ 'l',
+ "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
+ 35, 0, Integer.MAX_VALUE);
- public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
- "Disable anomaly Detection.");
+ public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
+ "Disable anomaly Detection.");
- public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
- "multivariateAnomalyProbabilityThresholdd",
- 'm',
- "Multivariate anomaly threshold value.",
- 0.99, 0.0, 1.0);
+ public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
+ "multivariateAnomalyProbabilityThresholdd",
+ 'm',
+ "Multivariate anomaly threshold value.",
+ 0.99, 0.0, 1.0);
- public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
- "univariateAnomalyprobabilityThreshold",
- 'u',
- "Univariate anomaly threshold value.",
- 0.10, 0.0, 1.0);
+ public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
+ "univariateAnomalyprobabilityThreshold",
+ 'u',
+ "Univariate anomaly threshold value.",
+ 0.10, 0.0, 1.0);
- public IntOption anomalyNumInstThresholdOption = new IntOption(
- "anomalyThreshold",
- 'n',
- "The threshold value of anomalies to be used in the anomaly detection.",
- 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15.
+ public IntOption anomalyNumInstThresholdOption = new IntOption(
+ "anomalyThreshold",
+ 'n',
+ "The threshold value of anomalies to be used in the anomaly detection.",
+ 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect
+ // anomalies. 15.
- public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
- "unorderedRules.");
+ public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
+ "unorderedRules.");
- public ClassOption numericObserverOption = new ClassOption("numericObserver",
- 'z', "Numeric observer.",
- FIMTDDNumericAttributeClassLimitObserver.class,
- "FIMTDDNumericAttributeClassLimitObserver");
+ public ClassOption numericObserverOption = new ClassOption("numericObserver",
+ 'z', "Numeric observer.",
+ FIMTDDNumericAttributeClassLimitObserver.class,
+ "FIMTDDNumericAttributeClassLimitObserver");
- public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
- "predictionFunctionOption", 'P', "The prediction function to use.", new String[]{
- "Adaptative","Perceptron", "Target Mean"}, new String[]{
- "Adaptative","Perceptron", "Target Mean"}, 0);
+ public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
+ "predictionFunctionOption", 'P', "The prediction function to use.", new String[] {
+ "Adaptative", "Perceptron", "Target Mean" }, new String[] {
+ "Adaptative", "Perceptron", "Target Mean" }, 0);
- public FlagOption constantLearningRatioDecayOption = new FlagOption(
- "learningRatio_Decay_set_constant", 'd',
- "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
+ public FlagOption constantLearningRatioDecayOption = new FlagOption(
+ "learningRatio_Decay_set_constant", 'd',
+ "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
- public FloatOption learningRatioOption = new FloatOption(
- "learningRatio", 's',
- "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);
+ public FloatOption learningRatioOption = new FloatOption(
+ "learningRatio", 's',
+ "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);
- public MultiChoiceOption votingTypeOption = new MultiChoiceOption(
- "votingType", 'V', "Voting Type.", new String[]{
- "InverseErrorWeightedVote","UniformWeightedVote"}, new String[]{
- "InverseErrorWeightedVote","UniformWeightedVote"}, 0);
-
- public IntOption parallelismHintOption = new IntOption(
- "parallelismHint",
- 'p',
- "The number of local statistics PI to do distributed computation",
- 1, 1, Integer.MAX_VALUE);
-
- // Processor
- private AMRulesAggregatorProcessor aggregator;
+ public MultiChoiceOption votingTypeOption = new MultiChoiceOption(
+ "votingType", 'V', "Voting Type.", new String[] {
+ "InverseErrorWeightedVote", "UniformWeightedVote" }, new String[] {
+ "InverseErrorWeightedVote", "UniformWeightedVote" }, 0);
- // Stream
- private Stream resultStream;
+ public IntOption parallelismHintOption = new IntOption(
+ "parallelismHint",
+ 'p',
+ "The number of local statistics PI to do distributed computation",
+ 1, 1, Integer.MAX_VALUE);
- @Override
- public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
+ // Processor
+ private AMRulesAggregatorProcessor aggregator;
- this.aggregator = new AMRulesAggregatorProcessor.Builder(dataset)
- .threshold(pageHinckleyThresholdOption.getValue())
- .alpha(pageHinckleyAlphaOption.getValue())
- .changeDetection(this.DriftDetectionOption.isSet())
- .predictionFunction(predictionFunctionOption.getChosenIndex())
- .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
- .learningRatio(learningRatioOption.getValue())
- .splitConfidence(splitConfidenceOption.getValue())
- .tieThreshold(tieThresholdOption.getValue())
- .gracePeriod(gracePeriodOption.getValue())
- .noAnomalyDetection(noAnomalyDetectionOption.isSet())
- .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
- .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
- .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
- .unorderedRules(unorderedRulesOption.isSet())
- .numericObserver((FIMTDDNumericAttributeClassLimitObserver)numericObserverOption.getValue())
- .voteType(votingTypeOption.getChosenIndex())
- .build();
+ // Stream
+ private Stream resultStream;
- topologyBuilder.addProcessor(aggregator);
+ @Override
+ public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
- Stream statisticsStream = topologyBuilder.createStream(aggregator);
- this.resultStream = topologyBuilder.createStream(aggregator);
-
- this.aggregator.setResultStream(resultStream);
- this.aggregator.setStatisticsStream(statisticsStream);
+ this.aggregator = new AMRulesAggregatorProcessor.Builder(dataset)
+ .threshold(pageHinckleyThresholdOption.getValue())
+ .alpha(pageHinckleyAlphaOption.getValue())
+ .changeDetection(this.DriftDetectionOption.isSet())
+ .predictionFunction(predictionFunctionOption.getChosenIndex())
+ .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
+ .learningRatio(learningRatioOption.getValue())
+ .splitConfidence(splitConfidenceOption.getValue())
+ .tieThreshold(tieThresholdOption.getValue())
+ .gracePeriod(gracePeriodOption.getValue())
+ .noAnomalyDetection(noAnomalyDetectionOption.isSet())
+ .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
+ .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
+ .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
+ .unorderedRules(unorderedRulesOption.isSet())
+ .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
+ .voteType(votingTypeOption.getChosenIndex())
+ .build();
- AMRulesStatisticsProcessor learner = new AMRulesStatisticsProcessor.Builder(dataset)
- .splitConfidence(splitConfidenceOption.getValue())
- .tieThreshold(tieThresholdOption.getValue())
- .gracePeriod(gracePeriodOption.getValue())
- .build();
-
- topologyBuilder.addProcessor(learner, this.parallelismHintOption.getValue());
-
- topologyBuilder.connectInputKeyStream(statisticsStream, learner);
+ topologyBuilder.addProcessor(aggregator);
- Stream predicateStream = topologyBuilder.createStream(learner);
- learner.setOutputStream(predicateStream);
-
- topologyBuilder.connectInputShuffleStream(predicateStream, aggregator);
- }
+ Stream statisticsStream = topologyBuilder.createStream(aggregator);
+ this.resultStream = topologyBuilder.createStream(aggregator);
- @Override
- public Processor getInputProcessor() {
- return aggregator;
- }
+ this.aggregator.setResultStream(resultStream);
+ this.aggregator.setStatisticsStream(statisticsStream);
- @Override
- public Set<Stream> getResultStreams() {
- return ImmutableSet.of(this.resultStream);
- }
+ AMRulesStatisticsProcessor learner = new AMRulesStatisticsProcessor.Builder(dataset)
+ .splitConfidence(splitConfidenceOption.getValue())
+ .tieThreshold(tieThresholdOption.getValue())
+ .gracePeriod(gracePeriodOption.getValue())
+ .build();
+
+ topologyBuilder.addProcessor(learner, this.parallelismHintOption.getValue());
+
+ topologyBuilder.connectInputKeyStream(statisticsStream, learner);
+
+ Stream predicateStream = topologyBuilder.createStream(learner);
+ learner.setOutputStream(predicateStream);
+
+ topologyBuilder.connectInputShuffleStream(predicateStream, aggregator);
+ }
+
+ @Override
+ public Processor getInputProcessor() {
+ return aggregator;
+ }
+
+ @Override
+ public Set<Stream> getResultStreams() {
+ return ImmutableSet.of(this.resultStream);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/centralized/AMRulesRegressorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/centralized/AMRulesRegressorProcessor.java
index f83d6fd..48e9dbb 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/centralized/AMRulesRegressorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/centralized/AMRulesRegressorProcessor.java
@@ -38,472 +38,469 @@
import com.yahoo.labs.samoa.topology.Stream;
/**
- * AMRules Regressor Processor
- * is the main (and only) processor for AMRulesRegressor task.
- * It is adapted from the AMRules implementation in MOA.
+ * AMRules Regressor Processor is the main (and only) processor for
+ * AMRulesRegressor task. It is adapted from the AMRules implementation in MOA.
*
* @author Anh Thu Vu
- *
+ *
*/
public class AMRulesRegressorProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- private int processorId;
+ private int processorId;
- // Rules & default rule
- protected List<ActiveRule> ruleSet;
- protected ActiveRule defaultRule;
- protected int ruleNumberID;
- protected double[] statistics;
+ // Rules & default rule
+ protected List<ActiveRule> ruleSet;
+ protected ActiveRule defaultRule;
+ protected int ruleNumberID;
+ protected double[] statistics;
- // SAMOA Stream
- private Stream resultStream;
+ // SAMOA Stream
+ private Stream resultStream;
- // Options
- protected int pageHinckleyThreshold;
- protected double pageHinckleyAlpha;
- protected boolean driftDetection;
- protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
- protected boolean constantLearningRatioDecay;
- protected double learningRatio;
+ // Options
+ protected int pageHinckleyThreshold;
+ protected double pageHinckleyAlpha;
+ protected boolean driftDetection;
+ protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+ protected boolean constantLearningRatioDecay;
+ protected double learningRatio;
- protected double splitConfidence;
- protected double tieThreshold;
- protected int gracePeriod;
+ protected double splitConfidence;
+ protected double tieThreshold;
+ protected int gracePeriod;
- protected boolean noAnomalyDetection;
- protected double multivariateAnomalyProbabilityThreshold;
- protected double univariateAnomalyprobabilityThreshold;
- protected int anomalyNumInstThreshold;
+ protected boolean noAnomalyDetection;
+ protected double multivariateAnomalyProbabilityThreshold;
+ protected double univariateAnomalyprobabilityThreshold;
+ protected int anomalyNumInstThreshold;
- protected boolean unorderedRules;
+ protected boolean unorderedRules;
- protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
-
- protected ErrorWeightedVote voteType;
-
- /*
- * Constructor
- */
- public AMRulesRegressorProcessor (Builder builder) {
- this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
- this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
- this.driftDetection = builder.driftDetection;
- this.predictionFunction = builder.predictionFunction;
- this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
- this.learningRatio = builder.learningRatio;
- this.splitConfidence = builder.splitConfidence;
- this.tieThreshold = builder.tieThreshold;
- this.gracePeriod = builder.gracePeriod;
-
- this.noAnomalyDetection = builder.noAnomalyDetection;
- this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
- this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
- this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
- this.unorderedRules = builder.unorderedRules;
-
- this.numericObserver = builder.numericObserver;
- this.voteType = builder.voteType;
- }
-
- /*
- * Process
- */
- @Override
- public boolean process(ContentEvent event) {
- InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
-
- // predict
- if (instanceEvent.isTesting()) {
- this.predictOnInstance(instanceEvent);
- }
-
- // train
- if (instanceEvent.isTraining()) {
- this.trainOnInstance(instanceEvent);
- }
-
- return true;
- }
-
- /*
- * Prediction
- */
- private void predictOnInstance (InstanceContentEvent instanceEvent) {
- double[] prediction = getVotesForInstance(instanceEvent.getInstance());
- ResultContentEvent rce = newResultContentEvent(prediction, instanceEvent);
- resultStream.put(rce);
- }
-
- /**
- * Helper method to generate new ResultContentEvent based on an instance and
- * its prediction result.
- * @param prediction The predicted class label from the decision tree model.
- * @param inEvent The associated instance content event
- * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
- */
- private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){
- ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent());
- rce.setClassifierIndex(this.processorId);
- rce.setEvaluationIndex(inEvent.getEvaluationIndex());
- return rce;
- }
-
- /**
- * getVotesForInstance extension of the instance method getVotesForInstance
- * in moa.classifier.java
- * returns the prediction of the instance.
- * Called in EvaluateModelRegression
- */
- private double[] getVotesForInstance(Instance instance) {
- ErrorWeightedVote errorWeightedVote=newErrorWeightedVote();
- int numberOfRulesCovering = 0;
-
- for (ActiveRule rule: ruleSet) {
- if (rule.isCovering(instance) == true){
- numberOfRulesCovering++;
- double [] vote=rule.getPrediction(instance);
- double error= rule.getCurrentError();
- errorWeightedVote.addVote(vote,error);
- if (!this.unorderedRules) { // Ordered Rules Option.
- break; // Only one rule cover the instance.
- }
- }
- }
+ protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
- if (numberOfRulesCovering == 0) {
- double [] vote=defaultRule.getPrediction(instance);
- double error= defaultRule.getCurrentError();
- errorWeightedVote.addVote(vote,error);
- }
- double[] weightedVote=errorWeightedVote.computeWeightedVote();
-
- return weightedVote;
- }
+ protected ErrorWeightedVote voteType;
- public ErrorWeightedVote newErrorWeightedVote() {
- return voteType.getACopy();
- }
-
- /*
- * Training
- */
- private void trainOnInstance (InstanceContentEvent instanceEvent) {
- this.trainOnInstanceImpl(instanceEvent.getInstance());
- }
- public void trainOnInstanceImpl(Instance instance) {
- /**
- * AMRules Algorithm
- *
- //For each rule in the rule set
- //If rule covers the instance
- //if the instance is not an anomaly
- //Update Change Detection Tests
- //Compute prediction error
- //Call PHTest
- //If change is detected then
- //Remove rule
- //Else
- //Update sufficient statistics of rule
- //If number of examples in rule > Nmin
- //Expand rule
- //If ordered set then
- //break
- //If none of the rule covers the instance
- //Update sufficient statistics of default rule
- //If number of examples in default rule is multiple of Nmin
- //Expand default rule and add it to the set of rules
- //Reset the default rule
- */
- boolean rulesCoveringInstance = false;
- Iterator<ActiveRule> ruleIterator= this.ruleSet.iterator();
- while (ruleIterator.hasNext()) {
- ActiveRule rule = ruleIterator.next();
- if (rule.isCovering(instance) == true) {
- rulesCoveringInstance = true;
- if (isAnomaly(instance, rule) == false) {
- //Update Change Detection Tests
- double error = rule.computeError(instance); //Use adaptive mode error
- boolean changeDetected = ((RuleActiveRegressionNode)rule.getLearningNode()).updateChangeDetection(error);
- if (changeDetected == true) {
- ruleIterator.remove();
- } else {
- rule.updateStatistics(instance);
- if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
- if (rule.tryToExpand(this.splitConfidence, this.tieThreshold) ) {
- rule.split();
- }
- }
- }
- if (!this.unorderedRules)
- break;
- }
- }
- }
+ /*
+ * Constructor
+ */
+ public AMRulesRegressorProcessor(Builder builder) {
+ this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
+ this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
+ this.driftDetection = builder.driftDetection;
+ this.predictionFunction = builder.predictionFunction;
+ this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
+ this.learningRatio = builder.learningRatio;
+ this.splitConfidence = builder.splitConfidence;
+ this.tieThreshold = builder.tieThreshold;
+ this.gracePeriod = builder.gracePeriod;
- if (rulesCoveringInstance == false){
- defaultRule.updateStatistics(instance);
- if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
- if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
- ActiveRule newDefaultRule=newRule(defaultRule.getRuleNumberID(),(RuleActiveRegressionNode)defaultRule.getLearningNode(),
- ((RuleActiveRegressionNode)defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); //other branch
- defaultRule.split();
- defaultRule.setRuleNumberID(++ruleNumberID);
- this.ruleSet.add(this.defaultRule);
-
- defaultRule=newDefaultRule;
+ this.noAnomalyDetection = builder.noAnomalyDetection;
+ this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
+ this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
+ this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
+ this.unorderedRules = builder.unorderedRules;
- }
- }
- }
- }
+ this.numericObserver = builder.numericObserver;
+ this.voteType = builder.voteType;
+ }
- /**
- * Method to verify if the instance is an anomaly.
- * @param instance
- * @param rule
- * @return
- */
- private boolean isAnomaly(Instance instance, ActiveRule rule) {
- //AMRUles is equipped with anomaly detection. If on, compute the anomaly value.
- boolean isAnomaly = false;
- if (this.noAnomalyDetection == false){
- if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
- isAnomaly = rule.isAnomaly(instance,
- this.univariateAnomalyprobabilityThreshold,
- this.multivariateAnomalyProbabilityThreshold,
- this.anomalyNumInstThreshold);
- }
- }
- return isAnomaly;
- }
-
- /*
- * Create new rules
- */
- // TODO check this after finish rule, LN
- private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
- ActiveRule r=newRule(ID);
+ /*
+ * Process
+ */
+ @Override
+ public boolean process(ContentEvent event) {
+ InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
- if (node!=null)
- {
- if(node.getPerceptron()!=null)
- {
- r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
- r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
- }
- if (statistics==null)
- {
- double mean;
- if(node.getNodeStatistics().getValue(0)>0){
- mean=node.getNodeStatistics().getValue(1)/node.getNodeStatistics().getValue(0);
- r.getLearningNode().getTargetMean().reset(mean, 1);
- }
- }
- }
- if (statistics!=null && ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean()!=null)
- {
- double mean;
- if(statistics[0]>0){
- mean=statistics[1]/statistics[0];
- ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean().reset(mean, (long)statistics[0]);
- }
- }
- return r;
- }
-
- private ActiveRule newRule(int ID) {
- ActiveRule r=new ActiveRule.Builder().
- threshold(this.pageHinckleyThreshold).
- alpha(this.pageHinckleyAlpha).
- changeDetection(this.driftDetection).
- predictionFunction(this.predictionFunction).
- statistics(new double[3]).
- learningRatio(this.learningRatio).
- numericObserver(numericObserver).
- id(ID).build();
- return r;
- }
-
- /*
- * Init processor
- */
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- this.statistics= new double[]{0.0,0,0};
- this.ruleNumberID=0;
- this.defaultRule = newRule(++this.ruleNumberID);
-
- this.ruleSet = new LinkedList<ActiveRule>();
- }
-
- /*
- * Clone processor
- */
- @Override
- public Processor newProcessor(Processor p) {
- AMRulesRegressorProcessor oldProcessor = (AMRulesRegressorProcessor) p;
- Builder builder = new Builder(oldProcessor);
- AMRulesRegressorProcessor newProcessor = builder.build();
- newProcessor.resultStream = oldProcessor.resultStream;
- return newProcessor;
- }
-
- /*
- * Output stream
- */
- public void setResultStream(Stream resultStream) {
- this.resultStream = resultStream;
- }
-
- public Stream getResultStream() {
- return this.resultStream;
- }
-
- /*
- * Others
- */
- public boolean isRandomizable() {
- return true;
+ // predict
+ if (instanceEvent.isTesting()) {
+ this.predictOnInstance(instanceEvent);
}
-
- /*
- * Builder
- */
- public static class Builder {
- private int pageHinckleyThreshold;
- private double pageHinckleyAlpha;
- private boolean driftDetection;
- private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
- private boolean constantLearningRatioDecay;
- private double learningRatio;
- private double splitConfidence;
- private double tieThreshold;
- private int gracePeriod;
-
- private boolean noAnomalyDetection;
- private double multivariateAnomalyProbabilityThreshold;
- private double univariateAnomalyprobabilityThreshold;
- private int anomalyNumInstThreshold;
-
- private boolean unorderedRules;
-
- private FIMTDDNumericAttributeClassLimitObserver numericObserver;
- private ErrorWeightedVote voteType;
-
- private Instances dataset;
-
- public Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- public Builder(AMRulesRegressorProcessor processor) {
- this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
- this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
- this.driftDetection = processor.driftDetection;
- this.predictionFunction = processor.predictionFunction;
- this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
- this.learningRatio = processor.learningRatio;
- this.splitConfidence = processor.splitConfidence;
- this.tieThreshold = processor.tieThreshold;
- this.gracePeriod = processor.gracePeriod;
-
- this.noAnomalyDetection = processor.noAnomalyDetection;
- this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
- this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
- this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
- this.unorderedRules = processor.unorderedRules;
-
- this.numericObserver = processor.numericObserver;
- this.voteType = processor.voteType;
- }
-
- public Builder threshold(int threshold) {
- this.pageHinckleyThreshold = threshold;
- return this;
- }
-
- public Builder alpha(double alpha) {
- this.pageHinckleyAlpha = alpha;
- return this;
- }
-
- public Builder changeDetection(boolean changeDetection) {
- this.driftDetection = changeDetection;
- return this;
- }
-
- public Builder predictionFunction(int predictionFunction) {
- this.predictionFunction = predictionFunction;
- return this;
- }
-
- public Builder constantLearningRatioDecay(boolean constantDecay) {
- this.constantLearningRatioDecay = constantDecay;
- return this;
- }
-
- public Builder learningRatio(double learningRatio) {
- this.learningRatio = learningRatio;
- return this;
- }
-
- public Builder splitConfidence(double splitConfidence) {
- this.splitConfidence = splitConfidence;
- return this;
- }
-
- public Builder tieThreshold(double tieThreshold) {
- this.tieThreshold = tieThreshold;
- return this;
- }
-
- public Builder gracePeriod(int gracePeriod) {
- this.gracePeriod = gracePeriod;
- return this;
- }
-
- public Builder noAnomalyDetection(boolean noAnomalyDetection) {
- this.noAnomalyDetection = noAnomalyDetection;
- return this;
- }
-
- public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
- this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
- return this;
- }
-
- public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
- this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
- return this;
- }
-
- public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
- this.anomalyNumInstThreshold = anomalyNumInstThreshold;
- return this;
- }
-
- public Builder unorderedRules(boolean unorderedRules) {
- this.unorderedRules = unorderedRules;
- return this;
- }
-
- public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
- this.numericObserver = numericObserver;
- return this;
- }
-
- public Builder voteType(ErrorWeightedVote voteType) {
- this.voteType = voteType;
- return this;
- }
-
- public AMRulesRegressorProcessor build() {
- return new AMRulesRegressorProcessor(this);
- }
- }
+
+ // train
+ if (instanceEvent.isTraining()) {
+ this.trainOnInstance(instanceEvent);
+ }
+
+ return true;
+ }
+
+ /*
+ * Prediction
+ */
+ private void predictOnInstance(InstanceContentEvent instanceEvent) {
+ double[] prediction = getVotesForInstance(instanceEvent.getInstance());
+ ResultContentEvent rce = newResultContentEvent(prediction, instanceEvent);
+ resultStream.put(rce);
+ }
+
+ /**
+ * Helper method to generate new ResultContentEvent based on an instance and
+ * its prediction result.
+ *
+ * @param prediction
+ * The predicted class label from the decision tree model.
+ * @param inEvent
+ * The associated instance content event
+ * @return ResultContentEvent to be sent into Evaluator PI or other
+ * destination PI.
+ */
+ private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+ ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
+ inEvent.getClassId(), prediction, inEvent.isLastEvent());
+ rce.setClassifierIndex(this.processorId);
+ rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+ return rce;
+ }
+
+ /**
+ * getVotesForInstance extension of the instance method getVotesForInstance in
+ * moa.classifier.java returns the prediction of the instance. Called in
+ * EvaluateModelRegression
+ */
+ private double[] getVotesForInstance(Instance instance) {
+ ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
+ int numberOfRulesCovering = 0;
+
+ for (ActiveRule rule : ruleSet) {
+ if (rule.isCovering(instance) == true) {
+ numberOfRulesCovering++;
+ double[] vote = rule.getPrediction(instance);
+ double error = rule.getCurrentError();
+ errorWeightedVote.addVote(vote, error);
+ if (!this.unorderedRules) { // Ordered Rules Option.
+ break; // Only one rule cover the instance.
+ }
+ }
+ }
+
+ if (numberOfRulesCovering == 0) {
+ double[] vote = defaultRule.getPrediction(instance);
+ double error = defaultRule.getCurrentError();
+ errorWeightedVote.addVote(vote, error);
+ }
+ double[] weightedVote = errorWeightedVote.computeWeightedVote();
+
+ return weightedVote;
+ }
+
+ public ErrorWeightedVote newErrorWeightedVote() {
+ return voteType.getACopy();
+ }
+
+ /*
+ * Training
+ */
+ private void trainOnInstance(InstanceContentEvent instanceEvent) {
+ this.trainOnInstanceImpl(instanceEvent.getInstance());
+ }
+
+ public void trainOnInstanceImpl(Instance instance) {
+ /**
+ * AMRules Algorithm
+ *
+ * //For each rule in the rule set //If rule covers the instance //if the
+ * instance is not an anomaly //Update Change Detection Tests //Compute
+ * prediction error //Call PHTest //If change is detected then //Remove rule
+ * //Else //Update sufficient statistics of rule //If number of examples in
+ * rule > Nmin //Expand rule //If ordered set then //break //If none of the
+ * rule covers the instance //Update sufficient statistics of default rule
+ * //If number of examples in default rule is multiple of Nmin //Expand
+ * default rule and add it to the set of rules //Reset the default rule
+ */
+ boolean rulesCoveringInstance = false;
+ Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
+ while (ruleIterator.hasNext()) {
+ ActiveRule rule = ruleIterator.next();
+ if (rule.isCovering(instance) == true) {
+ rulesCoveringInstance = true;
+ if (isAnomaly(instance, rule) == false) {
+ // Update Change Detection Tests
+ double error = rule.computeError(instance); // Use adaptive mode error
+ boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
+ if (changeDetected == true) {
+ ruleIterator.remove();
+ } else {
+ rule.updateStatistics(instance);
+ if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
+ if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
+ rule.split();
+ }
+ }
+ }
+ if (!this.unorderedRules)
+ break;
+ }
+ }
+ }
+
+ if (rulesCoveringInstance == false) {
+ defaultRule.updateStatistics(instance);
+ if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
+ if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
+ ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(),
+ (RuleActiveRegressionNode) defaultRule.getLearningNode(),
+ ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other
+ // branch
+ defaultRule.split();
+ defaultRule.setRuleNumberID(++ruleNumberID);
+ this.ruleSet.add(this.defaultRule);
+
+ defaultRule = newDefaultRule;
+
+ }
+ }
+ }
+ }
+
+ /**
+ * Method to verify if the instance is an anomaly.
+ *
+ * @param instance
+ * @param rule
+ * @return
+ */
+ private boolean isAnomaly(Instance instance, ActiveRule rule) {
+ // AMRUles is equipped with anomaly detection. If on, compute the anomaly
+ // value.
+ boolean isAnomaly = false;
+ if (this.noAnomalyDetection == false) {
+ if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
+ isAnomaly = rule.isAnomaly(instance,
+ this.univariateAnomalyprobabilityThreshold,
+ this.multivariateAnomalyProbabilityThreshold,
+ this.anomalyNumInstThreshold);
+ }
+ }
+ return isAnomaly;
+ }
+
+ /*
+ * Create new rules
+ */
+ // TODO check this after finish rule, LN
+ private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
+ ActiveRule r = newRule(ID);
+
+ if (node != null)
+ {
+ if (node.getPerceptron() != null)
+ {
+ r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
+ r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
+ }
+ if (statistics == null)
+ {
+ double mean;
+ if (node.getNodeStatistics().getValue(0) > 0) {
+ mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
+ r.getLearningNode().getTargetMean().reset(mean, 1);
+ }
+ }
+ }
+ if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null)
+ {
+ double mean;
+ if (statistics[0] > 0) {
+ mean = statistics[1] / statistics[0];
+ ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
+ }
+ }
+ return r;
+ }
+
+ private ActiveRule newRule(int ID) {
+ ActiveRule r = new ActiveRule.Builder().
+ threshold(this.pageHinckleyThreshold).
+ alpha(this.pageHinckleyAlpha).
+ changeDetection(this.driftDetection).
+ predictionFunction(this.predictionFunction).
+ statistics(new double[3]).
+ learningRatio(this.learningRatio).
+ numericObserver(numericObserver).
+ id(ID).build();
+ return r;
+ }
+
+ /*
+ * Init processor
+ */
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ this.statistics = new double[] { 0.0, 0, 0 };
+ this.ruleNumberID = 0;
+ this.defaultRule = newRule(++this.ruleNumberID);
+
+ this.ruleSet = new LinkedList<ActiveRule>();
+ }
+
+ /*
+ * Clone processor
+ */
+ @Override
+ public Processor newProcessor(Processor p) {
+ AMRulesRegressorProcessor oldProcessor = (AMRulesRegressorProcessor) p;
+ Builder builder = new Builder(oldProcessor);
+ AMRulesRegressorProcessor newProcessor = builder.build();
+ newProcessor.resultStream = oldProcessor.resultStream;
+ return newProcessor;
+ }
+
+ /*
+ * Output stream
+ */
+ public void setResultStream(Stream resultStream) {
+ this.resultStream = resultStream;
+ }
+
+ public Stream getResultStream() {
+ return this.resultStream;
+ }
+
+ /*
+ * Others
+ */
+ public boolean isRandomizable() {
+ return true;
+ }
+
+ /*
+ * Builder
+ */
+ public static class Builder {
+ private int pageHinckleyThreshold;
+ private double pageHinckleyAlpha;
+ private boolean driftDetection;
+ private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+ private boolean constantLearningRatioDecay;
+ private double learningRatio;
+ private double splitConfidence;
+ private double tieThreshold;
+ private int gracePeriod;
+
+ private boolean noAnomalyDetection;
+ private double multivariateAnomalyProbabilityThreshold;
+ private double univariateAnomalyprobabilityThreshold;
+ private int anomalyNumInstThreshold;
+
+ private boolean unorderedRules;
+
+ private FIMTDDNumericAttributeClassLimitObserver numericObserver;
+ private ErrorWeightedVote voteType;
+
+ private Instances dataset;
+
+ public Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ public Builder(AMRulesRegressorProcessor processor) {
+ this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
+ this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
+ this.driftDetection = processor.driftDetection;
+ this.predictionFunction = processor.predictionFunction;
+ this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
+ this.learningRatio = processor.learningRatio;
+ this.splitConfidence = processor.splitConfidence;
+ this.tieThreshold = processor.tieThreshold;
+ this.gracePeriod = processor.gracePeriod;
+
+ this.noAnomalyDetection = processor.noAnomalyDetection;
+ this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
+ this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
+ this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
+ this.unorderedRules = processor.unorderedRules;
+
+ this.numericObserver = processor.numericObserver;
+ this.voteType = processor.voteType;
+ }
+
+ public Builder threshold(int threshold) {
+ this.pageHinckleyThreshold = threshold;
+ return this;
+ }
+
+ public Builder alpha(double alpha) {
+ this.pageHinckleyAlpha = alpha;
+ return this;
+ }
+
+ public Builder changeDetection(boolean changeDetection) {
+ this.driftDetection = changeDetection;
+ return this;
+ }
+
+ public Builder predictionFunction(int predictionFunction) {
+ this.predictionFunction = predictionFunction;
+ return this;
+ }
+
+ public Builder constantLearningRatioDecay(boolean constantDecay) {
+ this.constantLearningRatioDecay = constantDecay;
+ return this;
+ }
+
+ public Builder learningRatio(double learningRatio) {
+ this.learningRatio = learningRatio;
+ return this;
+ }
+
+ public Builder splitConfidence(double splitConfidence) {
+ this.splitConfidence = splitConfidence;
+ return this;
+ }
+
+ public Builder tieThreshold(double tieThreshold) {
+ this.tieThreshold = tieThreshold;
+ return this;
+ }
+
+ public Builder gracePeriod(int gracePeriod) {
+ this.gracePeriod = gracePeriod;
+ return this;
+ }
+
+ public Builder noAnomalyDetection(boolean noAnomalyDetection) {
+ this.noAnomalyDetection = noAnomalyDetection;
+ return this;
+ }
+
+ public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
+ this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
+ return this;
+ }
+
+ public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
+ this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
+ return this;
+ }
+
+ public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
+ this.anomalyNumInstThreshold = anomalyNumInstThreshold;
+ return this;
+ }
+
+ public Builder unorderedRules(boolean unorderedRules) {
+ this.unorderedRules = unorderedRules;
+ return this;
+ }
+
+ public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
+ this.numericObserver = numericObserver;
+ return this;
+ }
+
+ public Builder voteType(ErrorWeightedVote voteType) {
+ this.voteType = voteType;
+ return this;
+ }
+
+ public AMRulesRegressorProcessor build() {
+ return new AMRulesRegressorProcessor(this);
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/ActiveRule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/ActiveRule.java
index b6fba99..5d02079 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/ActiveRule.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/ActiveRule.java
@@ -28,199 +28,202 @@
import com.yahoo.labs.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate;
/**
- * ActiveRule is a LearningRule that actively update its LearningNode
- * with incoming instances.
+ * ActiveRule is a LearningRule that actively update its LearningNode with
+ * incoming instances.
*
* @author Anh Thu Vu
- *
+ *
*/
public class ActiveRule extends LearningRule {
-
- private static final long serialVersionUID = 1L;
- private double[] statisticsOtherBranchSplit;
+ private static final long serialVersionUID = 1L;
- private Builder builder;
-
- private RuleActiveRegressionNode learningNode;
-
- private RuleSplitNode lastUpdatedRuleSplitNode;
-
- /*
- * Constructor with Builder
- */
- public ActiveRule() {
- super();
- this.builder = null;
- this.learningNode = null;
- this.ruleNumberID = 0;
- }
- public ActiveRule(Builder builder) {
- super();
- this.setBuilder(builder);
- this.learningNode = newRuleActiveLearningNode(builder);
- //JD - use builder ID to set ruleNumberID
- this.ruleNumberID=builder.id;
- }
+ private double[] statisticsOtherBranchSplit;
- private RuleActiveRegressionNode newRuleActiveLearningNode(Builder builder) {
- return new RuleActiveRegressionNode(builder);
- }
+ private Builder builder;
- /*
- * Setters & getters
- */
- public Builder getBuilder() {
- return builder;
- }
+ private RuleActiveRegressionNode learningNode;
- public void setBuilder(Builder builder) {
- this.builder = builder;
- }
-
- @Override
- public RuleRegressionNode getLearningNode() {
- return this.learningNode;
- }
+ private RuleSplitNode lastUpdatedRuleSplitNode;
- @Override
- public void setLearningNode(RuleRegressionNode learningNode) {
- this.learningNode = (RuleActiveRegressionNode) learningNode;
- }
-
- public double[] statisticsOtherBranchSplit() {
- return this.statisticsOtherBranchSplit;
- }
-
- public RuleSplitNode getLastUpdatedRuleSplitNode() {
- return this.lastUpdatedRuleSplitNode;
- }
+ /*
+ * Constructor with Builder
+ */
+ public ActiveRule() {
+ super();
+ this.builder = null;
+ this.learningNode = null;
+ this.ruleNumberID = 0;
+ }
- /*
- * Builder
- */
- public static class Builder implements Serializable {
+ public ActiveRule(Builder builder) {
+ super();
+ this.setBuilder(builder);
+ this.learningNode = newRuleActiveLearningNode(builder);
+ // JD - use builder ID to set ruleNumberID
+ this.ruleNumberID = builder.id;
+ }
- private static final long serialVersionUID = 1712887264918475622L;
- protected boolean changeDetection;
- protected boolean usePerceptron;
- protected double threshold;
- protected double alpha;
- protected int predictionFunction;
- protected boolean constantLearningRatioDecay;
- protected double learningRatio;
+ private RuleActiveRegressionNode newRuleActiveLearningNode(Builder builder) {
+ return new RuleActiveRegressionNode(builder);
+ }
- protected double[] statistics;
+ /*
+ * Setters & getters
+ */
+ public Builder getBuilder() {
+ return builder;
+ }
- protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
-
- protected double lastTargetMean;
+ public void setBuilder(Builder builder) {
+ this.builder = builder;
+ }
- public int id;
+ @Override
+ public RuleRegressionNode getLearningNode() {
+ return this.learningNode;
+ }
- public Builder() {
- }
+ @Override
+ public void setLearningNode(RuleRegressionNode learningNode) {
+ this.learningNode = (RuleActiveRegressionNode) learningNode;
+ }
- public Builder changeDetection(boolean changeDetection) {
- this.changeDetection = changeDetection;
- return this;
- }
+ public double[] statisticsOtherBranchSplit() {
+ return this.statisticsOtherBranchSplit;
+ }
- public Builder threshold(double threshold) {
- this.threshold = threshold;
- return this;
- }
+ public RuleSplitNode getLastUpdatedRuleSplitNode() {
+ return this.lastUpdatedRuleSplitNode;
+ }
- public Builder alpha(double alpha) {
- this.alpha = alpha;
- return this;
- }
+ /*
+ * Builder
+ */
+ public static class Builder implements Serializable {
- public Builder predictionFunction(int predictionFunction) {
- this.predictionFunction = predictionFunction;
- return this;
- }
+ private static final long serialVersionUID = 1712887264918475622L;
+ protected boolean changeDetection;
+ protected boolean usePerceptron;
+ protected double threshold;
+ protected double alpha;
+ protected int predictionFunction;
+ protected boolean constantLearningRatioDecay;
+ protected double learningRatio;
- public Builder statistics(double[] statistics) {
- this.statistics = statistics;
- return this;
- }
-
- public Builder constantLearningRatioDecay(boolean constantLearningRatioDecay) {
- this.constantLearningRatioDecay = constantLearningRatioDecay;
- return this;
- }
-
- public Builder learningRatio(double learningRatio) {
- this.learningRatio = learningRatio;
- return this;
- }
-
- public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
- this.numericObserver = numericObserver;
- return this;
- }
+ protected double[] statistics;
- public Builder id(int id) {
- this.id = id;
- return this;
- }
- public ActiveRule build() {
- return new ActiveRule(this);
- }
+ protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
- }
+ protected double lastTargetMean;
- /**
- * Try to Expand method.
- * @param splitConfidence
- * @param tieThreshold
- * @return
- */
- public boolean tryToExpand(double splitConfidence, double tieThreshold) {
+ public int id;
- boolean shouldSplit= this.learningNode.tryToExpand(splitConfidence, tieThreshold);
- return shouldSplit;
+ public Builder() {
+ }
- }
-
- //JD: Only call after tryToExpand returning true
- public void split()
- {
- //this.statisticsOtherBranchSplit = this.learningNode.getStatisticsOtherBranchSplit();
- //create a split node,
- int splitIndex = this.learningNode.getSplitIndex();
- InstanceConditionalTest st=this.learningNode.getBestSuggestion().splitTest;
- if(st instanceof NumericAttributeBinaryTest) {
- NumericAttributeBinaryTest splitTest = (NumericAttributeBinaryTest) st;
- NumericAttributeBinaryRulePredicate predicate = new NumericAttributeBinaryRulePredicate(
- splitTest.getAttsTestDependsOn()[0], splitTest.getSplitValue(),
- splitIndex + 1);
- lastUpdatedRuleSplitNode = new RuleSplitNode(predicate, this.learningNode.getStatisticsBranchSplit() );
- if (this.nodeListAdd(lastUpdatedRuleSplitNode)) {
- // create a new learning node
- RuleActiveRegressionNode newLearningNode = newRuleActiveLearningNode(this.getBuilder().statistics(this.learningNode.getStatisticsNewRuleActiveLearningNode()));
- newLearningNode.initialize(this.learningNode);
- this.learningNode = newLearningNode;
- }
- }
- else
- throw new UnsupportedOperationException("AMRules (currently) only supports numerical attributes.");
- }
+ public Builder changeDetection(boolean changeDetection) {
+ this.changeDetection = changeDetection;
+ return this;
+ }
-
-
-// protected void debug(String string, int level) {
-// if (this.amRules.VerbosityOption.getValue()>=level) {
-// System.out.println(string);
-// }
-//}
-
- /**
- * MOA GUI output
- */
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- }
+ public Builder threshold(double threshold) {
+ this.threshold = threshold;
+ return this;
+ }
+
+ public Builder alpha(double alpha) {
+ this.alpha = alpha;
+ return this;
+ }
+
+ public Builder predictionFunction(int predictionFunction) {
+ this.predictionFunction = predictionFunction;
+ return this;
+ }
+
+ public Builder statistics(double[] statistics) {
+ this.statistics = statistics;
+ return this;
+ }
+
+ public Builder constantLearningRatioDecay(boolean constantLearningRatioDecay) {
+ this.constantLearningRatioDecay = constantLearningRatioDecay;
+ return this;
+ }
+
+ public Builder learningRatio(double learningRatio) {
+ this.learningRatio = learningRatio;
+ return this;
+ }
+
+ public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
+ this.numericObserver = numericObserver;
+ return this;
+ }
+
+ public Builder id(int id) {
+ this.id = id;
+ return this;
+ }
+
+ public ActiveRule build() {
+ return new ActiveRule(this);
+ }
+
+ }
+
+ /**
+ * Try to Expand method.
+ *
+ * @param splitConfidence
+ * @param tieThreshold
+ * @return
+ */
+ public boolean tryToExpand(double splitConfidence, double tieThreshold) {
+
+ boolean shouldSplit = this.learningNode.tryToExpand(splitConfidence, tieThreshold);
+ return shouldSplit;
+
+ }
+
+ // JD: Only call after tryToExpand returning true
+ public void split()
+ {
+ // this.statisticsOtherBranchSplit =
+ // this.learningNode.getStatisticsOtherBranchSplit();
+ // create a split node,
+ int splitIndex = this.learningNode.getSplitIndex();
+ InstanceConditionalTest st = this.learningNode.getBestSuggestion().splitTest;
+ if (st instanceof NumericAttributeBinaryTest) {
+ NumericAttributeBinaryTest splitTest = (NumericAttributeBinaryTest) st;
+ NumericAttributeBinaryRulePredicate predicate = new NumericAttributeBinaryRulePredicate(
+ splitTest.getAttsTestDependsOn()[0], splitTest.getSplitValue(),
+ splitIndex + 1);
+ lastUpdatedRuleSplitNode = new RuleSplitNode(predicate, this.learningNode.getStatisticsBranchSplit());
+ if (this.nodeListAdd(lastUpdatedRuleSplitNode)) {
+ // create a new learning node
+ RuleActiveRegressionNode newLearningNode = newRuleActiveLearningNode(this.getBuilder().statistics(
+ this.learningNode.getStatisticsNewRuleActiveLearningNode()));
+ newLearningNode.initialize(this.learningNode);
+ this.learningNode = newLearningNode;
+ }
+ }
+ else
+ throw new UnsupportedOperationException("AMRules (currently) only supports numerical attributes.");
+ }
+
+ // protected void debug(String string, int level) {
+ // if (this.amRules.VerbosityOption.getValue()>=level) {
+ // System.out.println(string);
+ // }
+ // }
+
+ /**
+ * MOA GUI output
+ */
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java
index 4c05632..b380ae8 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java
@@ -28,95 +28,95 @@
* Rule with LearningNode (statistical data).
*
* @author Anh Thu Vu
- *
+ *
*/
public abstract class LearningRule extends Rule {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 1L;
-
- /*
- * Constructor
- */
- public LearningRule() {
- super();
- }
-
- /*
- * LearningNode
- */
- public abstract RuleRegressionNode getLearningNode();
+ private static final long serialVersionUID = 1L;
- public abstract void setLearningNode(RuleRegressionNode learningNode);
-
- /*
- * No. of instances seen
- */
- public long getInstancesSeen() {
- return this.getLearningNode().getInstancesSeen();
- }
+ /*
+ * Constructor
+ */
+ public LearningRule() {
+ super();
+ }
- /*
- * Error and change detection
- */
- public double computeError(Instance instance) {
- return this.getLearningNode().computeError(instance);
- }
+ /*
+ * LearningNode
+ */
+ public abstract RuleRegressionNode getLearningNode();
+ public abstract void setLearningNode(RuleRegressionNode learningNode);
- /*
- * Prediction
- */
- public double[] getPrediction(Instance instance, int mode) {
- return this.getLearningNode().getPrediction(instance, mode);
- }
+ /*
+ * No. of instances seen
+ */
+ public long getInstancesSeen() {
+ return this.getLearningNode().getInstancesSeen();
+ }
- public double[] getPrediction(Instance instance) {
- return this.getLearningNode().getPrediction(instance);
- }
-
- public double getCurrentError() {
- return this.getLearningNode().getCurrentError();
- }
-
- /*
- * Anomaly detection
- */
- public boolean isAnomaly(Instance instance,
- double uniVariateAnomalyProbabilityThreshold,
- double multiVariateAnomalyProbabilityThreshold,
- int numberOfInstanceesForAnomaly) {
- return this.getLearningNode().isAnomaly(instance, uniVariateAnomalyProbabilityThreshold,
- multiVariateAnomalyProbabilityThreshold,
- numberOfInstanceesForAnomaly);
- }
-
- /*
- * Update
- */
- public void updateStatistics(Instance instance) {
- this.getLearningNode().updateStatistics(instance);
- }
-
- public String printRule() {
- StringBuilder out = new StringBuilder();
- int indent = 1;
- StringUtils.appendIndented(out, indent, "Rule Nr." + this.ruleNumberID + " Instances seen:" + this.getLearningNode().getInstancesSeen() + "\n"); // AC
- for (RuleSplitNode node : nodeList) {
- StringUtils.appendIndented(out, indent, node.getSplitTest().toString());
- StringUtils.appendIndented(out, indent, " ");
- StringUtils.appendIndented(out, indent, node.toString());
- }
- DoubleVector pred = new DoubleVector(this.getLearningNode().getSimplePrediction());
- StringUtils.appendIndented(out, 0, " --> y: " + pred.toString());
- StringUtils.appendNewline(out);
+ /*
+ * Error and change detection
+ */
+ public double computeError(Instance instance) {
+ return this.getLearningNode().computeError(instance);
+ }
- if(getLearningNode().perceptron!=null){
- ((RuleActiveRegressionNode)this.getLearningNode()).perceptron.getModelDescription(out,0);
- StringUtils.appendNewline(out);
- }
- return(out.toString());
- }
+ /*
+ * Prediction
+ */
+ public double[] getPrediction(Instance instance, int mode) {
+ return this.getLearningNode().getPrediction(instance, mode);
+ }
+
+ public double[] getPrediction(Instance instance) {
+ return this.getLearningNode().getPrediction(instance);
+ }
+
+ public double getCurrentError() {
+ return this.getLearningNode().getCurrentError();
+ }
+
+ /*
+ * Anomaly detection
+ */
+ public boolean isAnomaly(Instance instance,
+ double uniVariateAnomalyProbabilityThreshold,
+ double multiVariateAnomalyProbabilityThreshold,
+ int numberOfInstanceesForAnomaly) {
+ return this.getLearningNode().isAnomaly(instance, uniVariateAnomalyProbabilityThreshold,
+ multiVariateAnomalyProbabilityThreshold,
+ numberOfInstanceesForAnomaly);
+ }
+
+ /*
+ * Update
+ */
+ public void updateStatistics(Instance instance) {
+ this.getLearningNode().updateStatistics(instance);
+ }
+
+ public String printRule() {
+ StringBuilder out = new StringBuilder();
+ int indent = 1;
+ StringUtils.appendIndented(out, indent, "Rule Nr." + this.ruleNumberID + " Instances seen:"
+ + this.getLearningNode().getInstancesSeen() + "\n"); // AC
+ for (RuleSplitNode node : nodeList) {
+ StringUtils.appendIndented(out, indent, node.getSplitTest().toString());
+ StringUtils.appendIndented(out, indent, " ");
+ StringUtils.appendIndented(out, indent, node.toString());
+ }
+ DoubleVector pred = new DoubleVector(this.getLearningNode().getSimplePrediction());
+ StringUtils.appendIndented(out, 0, " --> y: " + pred.toString());
+ StringUtils.appendNewline(out);
+
+ if (getLearningNode().perceptron != null) {
+ ((RuleActiveRegressionNode) this.getLearningNode()).perceptron.getModelDescription(out, 0);
+ StringUtils.appendNewline(out);
+ }
+ return (out.toString());
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java
index df5b9f9..7679dc3 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java
@@ -24,28 +24,28 @@
* The most basic rule: inherit from Rule the ID and list of features.
*
* @author Anh Thu Vu
- *
+ *
*/
/*
* This branch (Non-learning rule) was created for an old implementation.
- * Probably should remove None-Learning and Learning Rule classes,
- * merge Rule with LearningRule.
+ * Probably should remove Non-Learning and Learning Rule classes, merge Rule
+ * with LearningRule.
*/
public class NonLearningRule extends Rule {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -1210907339230307784L;
+ private static final long serialVersionUID = -1210907339230307784L;
- public NonLearningRule(ActiveRule rule) {
- this.nodeList = rule.nodeList;
- this.ruleNumberID = rule.ruleNumberID;
- }
+ public NonLearningRule(ActiveRule rule) {
+ this.nodeList = rule.nodeList;
+ this.ruleNumberID = rule.ruleNumberID;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // do nothing
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // do nothing
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java
index 8281d45..f49d0d9 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java
@@ -23,49 +23,49 @@
import java.util.LinkedList;
/**
- * PassiveRule is a LearningRule that update its LearningNode
- * with the received new LearningNode.
+ * PassiveRule is a LearningRule that updates its LearningNode with the received
+ * new LearningNode.
*
* @author Anh Thu Vu
- *
+ *
*/
public class PassiveRule extends LearningRule {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -5551571895910530275L;
-
- private RulePassiveRegressionNode learningNode;
+ private static final long serialVersionUID = -5551571895910530275L;
- /*
- * Constructor to turn an ActiveRule into a PassiveRule
- */
- public PassiveRule(ActiveRule rule) {
- this.nodeList = new LinkedList<>();
- for (RuleSplitNode node:rule.nodeList) {
- this.nodeList.add(node.getACopy());
- }
-
- this.learningNode = new RulePassiveRegressionNode(rule.getLearningNode());
- this.ruleNumberID = rule.ruleNumberID;
- }
-
- @Override
- public RuleRegressionNode getLearningNode() {
- return this.learningNode;
- }
+ private RulePassiveRegressionNode learningNode;
- @Override
- public void setLearningNode(RuleRegressionNode learningNode) {
- this.learningNode = (RulePassiveRegressionNode) learningNode;
- }
-
- /*
- * MOA GUI
- */
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ /*
+ * Constructor to turn an ActiveRule into a PassiveRule
+ */
+ public PassiveRule(ActiveRule rule) {
+ this.nodeList = new LinkedList<>();
+ for (RuleSplitNode node : rule.nodeList) {
+ this.nodeList.add(node.getACopy());
+ }
+
+ this.learningNode = new RulePassiveRegressionNode(rule.getLearningNode());
+ this.ruleNumberID = rule.ruleNumberID;
+ }
+
+ @Override
+ public RuleRegressionNode getLearningNode() {
+ return this.learningNode;
+ }
+
+ @Override
+ public void setLearningNode(RuleRegressionNode learningNode) {
+ this.learningNode = (RulePassiveRegressionNode) learningNode;
+ }
+
+ /*
+ * MOA GUI
+ */
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java
index 53583ed..463dcdf 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java
@@ -33,455 +33,458 @@
import com.yahoo.labs.samoa.moa.core.Measurement;
/**
- * Prediction scheme using Perceptron:
- * Predictions are computed according to a linear function of the attributes.
+ * Prediction scheme using Perceptron: Predictions are computed according to a
+ * linear function of the attributes.
*
* @author Anh Thu Vu
- *
+ *
*/
public class Perceptron extends AbstractClassifier implements Regressor {
- private final double SD_THRESHOLD = 0.0000001; //THRESHOLD for normalizing attribute and target values
+ private final double SD_THRESHOLD = 0.0000001; // THRESHOLD for normalizing
+ // attribute and target values
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- // public FlagOption constantLearningRatioDecayOption = new FlagOption(
- // "learningRatio_Decay_set_constant", 'd',
- // "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
- //
- // public FloatOption learningRatioOption = new FloatOption(
- // "learningRatio", 'l',
- // "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.01);
- //
- // public FloatOption learningRateDecayOption = new FloatOption(
- // "learningRateDecay", 'm',
- // " Learning Rate decay to use for training the Perceptron.", 0.001);
- //
- // public FloatOption fadingFactorOption = new FloatOption(
- // "fadingFactor", 'e',
- // "Fading factor for the Perceptron accumulated error", 0.99, 0, 1);
+ // public FlagOption constantLearningRatioDecayOption = new FlagOption(
+ // "learningRatio_Decay_set_constant", 'd',
+ // "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
+ //
+ // public FloatOption learningRatioOption = new FloatOption(
+ // "learningRatio", 'l',
+ // "Constante Learning Ratio to use for training the Perceptrons in the leaves.",
+ // 0.01);
+ //
+ // public FloatOption learningRateDecayOption = new FloatOption(
+ // "learningRateDecay", 'm',
+ // " Learning Rate decay to use for training the Perceptron.", 0.001);
+ //
+ // public FloatOption fadingFactorOption = new FloatOption(
+ // "fadingFactor", 'e',
+ // "Fading factor for the Perceptron accumulated error", 0.99, 0, 1);
- protected boolean constantLearningRatioDecay;
- protected double originalLearningRatio;
+ protected boolean constantLearningRatioDecay;
+ protected double originalLearningRatio;
- private double nError;
- protected double fadingFactor = 0.99;
- private double learningRatio;
- protected double learningRateDecay = 0.001;
+ private double nError;
+ protected double fadingFactor = 0.99;
+ private double learningRatio;
+ protected double learningRateDecay = 0.001;
- // The Perception weights
- protected double[] weightAttribute;
+ // The Perceptron weights
+ protected double[] weightAttribute;
- // Statistics used for error calculations
- public DoubleVector perceptronattributeStatistics = new DoubleVector();
- public DoubleVector squaredperceptronattributeStatistics = new DoubleVector();
+ // Statistics used for error calculations
+ public DoubleVector perceptronattributeStatistics = new DoubleVector();
+ public DoubleVector squaredperceptronattributeStatistics = new DoubleVector();
- // The number of instances contributing to this model
- protected int perceptronInstancesSeen;
- protected int perceptronYSeen;
+ // The number of instances contributing to this model
+ protected int perceptronInstancesSeen;
+ protected int perceptronYSeen;
- protected double accumulatedError;
+ protected double accumulatedError;
- // If the model (weights) should be reset or not
- protected boolean initialisePerceptron;
+ // If the model (weights) should be reset or not
+ protected boolean initialisePerceptron;
- protected double perceptronsumY;
- protected double squaredperceptronsumY;
+ protected double perceptronsumY;
+ protected double squaredperceptronsumY;
+ public Perceptron() {
+ this.initialisePerceptron = true;
+ }
- public Perceptron() {
- this.initialisePerceptron = true;
- }
+ /*
+ * Perceptron
+ */
+ public Perceptron(Perceptron p) {
+ this(p, false);
+ }
- /*
- * Perceptron
- */
- public Perceptron(Perceptron p) {
- this(p,false);
- }
-
- public Perceptron(Perceptron p, boolean copyAccumulatedError) {
- super();
- // this.constantLearningRatioDecayOption = p.constantLearningRatioDecayOption;
- // this.learningRatioOption = p.learningRatioOption;
- // this.learningRateDecayOption=p.learningRateDecayOption;
- // this.fadingFactorOption = p.fadingFactorOption;
- this.constantLearningRatioDecay = p.constantLearningRatioDecay;
- this.originalLearningRatio = p.originalLearningRatio;
- if (copyAccumulatedError)
- this.accumulatedError = p.accumulatedError;
- this.nError = p.nError;
- this.fadingFactor = p.fadingFactor;
- this.learningRatio = p.learningRatio;
- this.learningRateDecay = p.learningRateDecay;
- if (p.weightAttribute!=null)
- this.weightAttribute = p.weightAttribute.clone();
+ public Perceptron(Perceptron p, boolean copyAccumulatedError) {
+ super();
+ // this.constantLearningRatioDecayOption =
+ // p.constantLearningRatioDecayOption;
+ // this.learningRatioOption = p.learningRatioOption;
+ // this.learningRateDecayOption=p.learningRateDecayOption;
+ // this.fadingFactorOption = p.fadingFactorOption;
+ this.constantLearningRatioDecay = p.constantLearningRatioDecay;
+ this.originalLearningRatio = p.originalLearningRatio;
+ if (copyAccumulatedError)
+ this.accumulatedError = p.accumulatedError;
+ this.nError = p.nError;
+ this.fadingFactor = p.fadingFactor;
+ this.learningRatio = p.learningRatio;
+ this.learningRateDecay = p.learningRateDecay;
+ if (p.weightAttribute != null)
+ this.weightAttribute = p.weightAttribute.clone();
- this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics);
- this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics);
- this.perceptronInstancesSeen = p.perceptronInstancesSeen;
+ this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics);
+ this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics);
+ this.perceptronInstancesSeen = p.perceptronInstancesSeen;
- this.initialisePerceptron = p.initialisePerceptron;
- this.perceptronsumY = p.perceptronsumY;
- this.squaredperceptronsumY = p.squaredperceptronsumY;
- this.perceptronYSeen=p.perceptronYSeen;
- }
-
- public Perceptron(PerceptronData p) {
- super();
- this.constantLearningRatioDecay = p.constantLearningRatioDecay;
- this.originalLearningRatio = p.originalLearningRatio;
- this.nError = p.nError;
- this.fadingFactor = p.fadingFactor;
- this.learningRatio = p.learningRatio;
- this.learningRateDecay = p.learningRateDecay;
- if (p.weightAttribute!=null)
- this.weightAttribute = p.weightAttribute.clone();
+ this.initialisePerceptron = p.initialisePerceptron;
+ this.perceptronsumY = p.perceptronsumY;
+ this.squaredperceptronsumY = p.squaredperceptronsumY;
+ this.perceptronYSeen = p.perceptronYSeen;
+ }
- this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics);
- this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics);
- this.perceptronInstancesSeen = p.perceptronInstancesSeen;
+ public Perceptron(PerceptronData p) {
+ super();
+ this.constantLearningRatioDecay = p.constantLearningRatioDecay;
+ this.originalLearningRatio = p.originalLearningRatio;
+ this.nError = p.nError;
+ this.fadingFactor = p.fadingFactor;
+ this.learningRatio = p.learningRatio;
+ this.learningRateDecay = p.learningRateDecay;
+ if (p.weightAttribute != null)
+ this.weightAttribute = p.weightAttribute.clone();
- this.initialisePerceptron = p.initialisePerceptron;
- this.perceptronsumY = p.perceptronsumY;
- this.squaredperceptronsumY = p.squaredperceptronsumY;
- this.perceptronYSeen=p.perceptronYSeen;
- this.accumulatedError = p.accumulatedError;
- }
+ this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics);
+ this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics);
+ this.perceptronInstancesSeen = p.perceptronInstancesSeen;
- // private void printPerceptron() {
- // System.out.println("Learning Ratio:"+this.learningRatio+" ("+this.originalLearningRatio+")");
- // System.out.println("Constant Learning Ratio Decay:"+this.constantLearningRatioDecay+" ("+this.learningRateDecay+")");
- // System.out.println("Error:"+this.accumulatedError+"/"+this.nError);
- // System.out.println("Fading factor:"+this.fadingFactor);
- // System.out.println("Perceptron Y:"+this.perceptronsumY+"/"+this.squaredperceptronsumY+"/"+this.perceptronYSeen);
- // }
+ this.initialisePerceptron = p.initialisePerceptron;
+ this.perceptronsumY = p.perceptronsumY;
+ this.squaredperceptronsumY = p.squaredperceptronsumY;
+ this.perceptronYSeen = p.perceptronYSeen;
+ this.accumulatedError = p.accumulatedError;
+ }
- /*
- * Weights
- */
- public void setWeights(double[] w) {
- this.weightAttribute = w;
- }
+ // private void printPerceptron() {
+ // System.out.println("Learning Ratio:"+this.learningRatio+" ("+this.originalLearningRatio+")");
+ // System.out.println("Constant Learning Ratio Decay:"+this.constantLearningRatioDecay+" ("+this.learningRateDecay+")");
+ // System.out.println("Error:"+this.accumulatedError+"/"+this.nError);
+ // System.out.println("Fading factor:"+this.fadingFactor);
+ // System.out.println("Perceptron Y:"+this.perceptronsumY+"/"+this.squaredperceptronsumY+"/"+this.perceptronYSeen);
+ // }
- public double[] getWeights() {
- return this.weightAttribute;
- }
+ /*
+ * Weights
+ */
+ public void setWeights(double[] w) {
+ this.weightAttribute = w;
+ }
- /*
- * No. of instances seen
- */
- public int getInstancesSeen() {
- return perceptronInstancesSeen;
- }
+ public double[] getWeights() {
+ return this.weightAttribute;
+ }
- public void setInstancesSeen(int pInstancesSeen) {
- this.perceptronInstancesSeen = pInstancesSeen;
- }
+ /*
+ * No. of instances seen
+ */
+ public int getInstancesSeen() {
+ return perceptronInstancesSeen;
+ }
- /**
- * A method to reset the model
- */
- public void resetLearningImpl() {
- this.initialisePerceptron = true;
- this.reset();
- }
+ public void setInstancesSeen(int pInstancesSeen) {
+ this.perceptronInstancesSeen = pInstancesSeen;
+ }
- public void reset(){
- this.nError=0.0;
- this.accumulatedError = 0.0;
- this.perceptronInstancesSeen = 0;
- this.perceptronattributeStatistics = new DoubleVector();
- this.squaredperceptronattributeStatistics = new DoubleVector();
- this.perceptronsumY = 0.0;
- this.squaredperceptronsumY = 0.0;
- this.perceptronYSeen=0;
- }
+ /**
+ * A method to reset the model
+ */
+ public void resetLearningImpl() {
+ this.initialisePerceptron = true;
+ this.reset();
+ }
- public void resetError(){
- this.nError=0.0;
- this.accumulatedError = 0.0;
- }
+ public void reset() {
+ this.nError = 0.0;
+ this.accumulatedError = 0.0;
+ this.perceptronInstancesSeen = 0;
+ this.perceptronattributeStatistics = new DoubleVector();
+ this.squaredperceptronattributeStatistics = new DoubleVector();
+ this.perceptronsumY = 0.0;
+ this.squaredperceptronsumY = 0.0;
+ this.perceptronYSeen = 0;
+ }
- /**
- * Update the model using the provided instance
- */
- public void trainOnInstanceImpl(Instance inst) {
- accumulatedError= Math.abs(this.prediction(inst)-inst.classValue()) + fadingFactor*accumulatedError;
- nError=1+fadingFactor*nError;
- // Initialise Perceptron if necessary
- if (this.initialisePerceptron) {
- //this.fadingFactor=this.fadingFactorOption.getValue();
- //this.classifierRandom.setSeed(randomSeedOption.getValue());
- this.classifierRandom.setSeed(randomSeed);
- this.initialisePerceptron = false; // not in resetLearningImpl() because it needs Instance!
- this.weightAttribute = new double[inst.numAttributes()];
- for (int j = 0; j < inst.numAttributes(); j++) {
- weightAttribute[j] = 2 * this.classifierRandom.nextDouble() - 1;
- }
- // Update Learning Rate
- learningRatio = originalLearningRatio;
- //this.learningRateDecay = learningRateDecayOption.getValue();
+ public void resetError() {
+ this.nError = 0.0;
+ this.accumulatedError = 0.0;
+ }
- }
+ /**
+ * Update the model using the provided instance
+ */
+ public void trainOnInstanceImpl(Instance inst) {
+ accumulatedError = Math.abs(this.prediction(inst) - inst.classValue()) + fadingFactor * accumulatedError;
+ nError = 1 + fadingFactor * nError;
+ // Initialise Perceptron if necessary
+ if (this.initialisePerceptron) {
+ // this.fadingFactor=this.fadingFactorOption.getValue();
+ // this.classifierRandom.setSeed(randomSeedOption.getValue());
+ this.classifierRandom.setSeed(randomSeed);
+ this.initialisePerceptron = false; // not in resetLearningImpl() because
+ // it needs Instance!
+ this.weightAttribute = new double[inst.numAttributes()];
+ for (int j = 0; j < inst.numAttributes(); j++) {
+ weightAttribute[j] = 2 * this.classifierRandom.nextDouble() - 1;
+ }
+ // Update Learning Rate
+ learningRatio = originalLearningRatio;
+ // this.learningRateDecay = learningRateDecayOption.getValue();
- // Update attribute statistics
- this.perceptronInstancesSeen++;
- this.perceptronYSeen++;
+ }
+ // Update attribute statistics
+ this.perceptronInstancesSeen++;
+ this.perceptronYSeen++;
- for(int j = 0; j < inst.numAttributes() -1; j++)
- {
- perceptronattributeStatistics.addToValue(j, inst.value(j));
- squaredperceptronattributeStatistics.addToValue(j, inst.value(j)*inst.value(j));
- }
- this.perceptronsumY += inst.classValue();
- this.squaredperceptronsumY += inst.classValue() * inst.classValue();
+ for (int j = 0; j < inst.numAttributes() - 1; j++)
+ {
+ perceptronattributeStatistics.addToValue(j, inst.value(j));
+ squaredperceptronattributeStatistics.addToValue(j, inst.value(j) * inst.value(j));
+ }
+ this.perceptronsumY += inst.classValue();
+ this.squaredperceptronsumY += inst.classValue() * inst.classValue();
- if(!constantLearningRatioDecay){
- learningRatio = originalLearningRatio / (1+ perceptronInstancesSeen*learningRateDecay);
- }
+ if (!constantLearningRatioDecay) {
+ learningRatio = originalLearningRatio / (1 + perceptronInstancesSeen * learningRateDecay);
+ }
- this.updateWeights(inst,learningRatio);
- //this.printPerceptron();
- }
+ this.updateWeights(inst, learningRatio);
+ // this.printPerceptron();
+ }
- /**
- * Output the prediction made by this perceptron on the given instance
- */
- private double prediction(Instance inst)
- {
- double[] normalizedInstance = normalizedInstance(inst);
- double normalizedPrediction = prediction(normalizedInstance);
- return denormalizedPrediction(normalizedPrediction);
- }
+ /**
+ * Output the prediction made by this perceptron on the given instance
+ */
+ private double prediction(Instance inst)
+ {
+ double[] normalizedInstance = normalizedInstance(inst);
+ double normalizedPrediction = prediction(normalizedInstance);
+ return denormalizedPrediction(normalizedPrediction);
+ }
- public double normalizedPrediction(Instance inst)
- {
- double[] normalizedInstance = normalizedInstance(inst);
- return prediction(normalizedInstance);
- }
+ public double normalizedPrediction(Instance inst)
+ {
+ double[] normalizedInstance = normalizedInstance(inst);
+ return prediction(normalizedInstance);
+ }
- private double denormalizedPrediction(double normalizedPrediction) {
- if (!this.initialisePerceptron){
- double meanY = perceptronsumY / perceptronYSeen;
- double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen);
- if (sdY > SD_THRESHOLD)
- return normalizedPrediction * sdY + meanY;
- else
- return normalizedPrediction + meanY;
- }
- else
- return normalizedPrediction; //Perceptron may have been "reseted". Use old weights to predict
+ private double denormalizedPrediction(double normalizedPrediction) {
+ if (!this.initialisePerceptron) {
+ double meanY = perceptronsumY / perceptronYSeen;
+ double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen);
+ if (sdY > SD_THRESHOLD)
+ return normalizedPrediction * sdY + meanY;
+ else
+ return normalizedPrediction + meanY;
+ }
+ else
+ return normalizedPrediction; // Perceptron may have been "reseted". Use
+ // old weights to predict
- }
+ }
- public double prediction(double[] instanceValues)
- {
- double prediction = 0.0;
- if(!this.initialisePerceptron)
- {
- for (int j = 0; j < instanceValues.length - 1; j++) {
- prediction += this.weightAttribute[j] * instanceValues[j];
- }
- prediction += this.weightAttribute[instanceValues.length - 1];
- }
- return prediction;
- }
+ public double prediction(double[] instanceValues)
+ {
+ double prediction = 0.0;
+ if (!this.initialisePerceptron)
+ {
+ for (int j = 0; j < instanceValues.length - 1; j++) {
+ prediction += this.weightAttribute[j] * instanceValues[j];
+ }
+ prediction += this.weightAttribute[instanceValues.length - 1];
+ }
+ return prediction;
+ }
- public double[] normalizedInstance(Instance inst){
- // Normalize Instance
- double[] normalizedInstance = new double[inst.numAttributes()];
- for(int j = 0; j < inst.numAttributes() -1; j++) {
- int instAttIndex = modelAttIndexToInstanceAttIndex(j);
- double mean = perceptronattributeStatistics.getValue(j) / perceptronYSeen;
- double sd = computeSD(squaredperceptronattributeStatistics.getValue(j), perceptronattributeStatistics.getValue(j), perceptronYSeen);
- if (sd > SD_THRESHOLD)
- normalizedInstance[j] = (inst.value(instAttIndex) - mean)/ sd;
- else
- normalizedInstance[j] = inst.value(instAttIndex) - mean;
- }
- return normalizedInstance;
- }
+ public double[] normalizedInstance(Instance inst) {
+ // Normalize Instance
+ double[] normalizedInstance = new double[inst.numAttributes()];
+ for (int j = 0; j < inst.numAttributes() - 1; j++) {
+ int instAttIndex = modelAttIndexToInstanceAttIndex(j);
+ double mean = perceptronattributeStatistics.getValue(j) / perceptronYSeen;
+ double sd = computeSD(squaredperceptronattributeStatistics.getValue(j),
+ perceptronattributeStatistics.getValue(j), perceptronYSeen);
+ if (sd > SD_THRESHOLD)
+ normalizedInstance[j] = (inst.value(instAttIndex) - mean) / sd;
+ else
+ normalizedInstance[j] = inst.value(instAttIndex) - mean;
+ }
+ return normalizedInstance;
+ }
- public double computeSD(double squaredVal, double val, int size) {
- if (size > 1) {
- return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0));
- }
- return 0.0;
- }
+ public double computeSD(double squaredVal, double val, int size) {
+ if (size > 1) {
+ return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0));
+ }
+ return 0.0;
+ }
- public double updateWeights(Instance inst, double learningRatio ){
- // Normalize Instance
- double[] normalizedInstance = normalizedInstance(inst);
- // Compute the Normalized Prediction of Perceptron
- double normalizedPredict= prediction(normalizedInstance);
- double normalizedY = normalizeActualClassValue(inst);
- double sumWeights = 0.0;
- double delta = normalizedY - normalizedPredict;
+ public double updateWeights(Instance inst, double learningRatio) {
+ // Normalize Instance
+ double[] normalizedInstance = normalizedInstance(inst);
+ // Compute the Normalized Prediction of Perceptron
+ double normalizedPredict = prediction(normalizedInstance);
+ double normalizedY = normalizeActualClassValue(inst);
+ double sumWeights = 0.0;
+ double delta = normalizedY - normalizedPredict;
- for (int j = 0; j < inst.numAttributes() - 1; j++) {
- int instAttIndex = modelAttIndexToInstanceAttIndex(j);
- if(inst.attribute(instAttIndex).isNumeric()) {
- this.weightAttribute[j] += learningRatio * delta * normalizedInstance[j];
- sumWeights += Math.abs(this.weightAttribute[j]);
- }
- }
- this.weightAttribute[inst.numAttributes() - 1] += learningRatio * delta;
- sumWeights += Math.abs(this.weightAttribute[inst.numAttributes() - 1]);
- if (sumWeights > inst.numAttributes()) { // Lasso regression
- for (int j = 0; j < inst.numAttributes() - 1; j++) {
- int instAttIndex = modelAttIndexToInstanceAttIndex(j);
- if(inst.attribute(instAttIndex).isNumeric()) {
- this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
- }
- }
- this.weightAttribute[inst.numAttributes() - 1] = this.weightAttribute[inst.numAttributes() - 1] / sumWeights;
- }
+ for (int j = 0; j < inst.numAttributes() - 1; j++) {
+ int instAttIndex = modelAttIndexToInstanceAttIndex(j);
+ if (inst.attribute(instAttIndex).isNumeric()) {
+ this.weightAttribute[j] += learningRatio * delta * normalizedInstance[j];
+ sumWeights += Math.abs(this.weightAttribute[j]);
+ }
+ }
+ this.weightAttribute[inst.numAttributes() - 1] += learningRatio * delta;
+ sumWeights += Math.abs(this.weightAttribute[inst.numAttributes() - 1]);
+ if (sumWeights > inst.numAttributes()) { // Lasso regression
+ for (int j = 0; j < inst.numAttributes() - 1; j++) {
+ int instAttIndex = modelAttIndexToInstanceAttIndex(j);
+ if (inst.attribute(instAttIndex).isNumeric()) {
+ this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
+ }
+ }
+ this.weightAttribute[inst.numAttributes() - 1] = this.weightAttribute[inst.numAttributes() - 1] / sumWeights;
+ }
- return denormalizedPrediction(normalizedPredict);
- }
+ return denormalizedPrediction(normalizedPredict);
+ }
- public void normalizeWeights(){
- double sumWeights = 0.0;
+ public void normalizeWeights() {
+ double sumWeights = 0.0;
- for (double aWeightAttribute : this.weightAttribute) {
- sumWeights += Math.abs(aWeightAttribute);
- }
- for (int j = 0; j < this.weightAttribute.length; j++) {
- this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
- }
- }
+ for (double aWeightAttribute : this.weightAttribute) {
+ sumWeights += Math.abs(aWeightAttribute);
+ }
+ for (int j = 0; j < this.weightAttribute.length; j++) {
+ this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
+ }
+ }
- private double normalizeActualClassValue(Instance inst) {
- double meanY = perceptronsumY / perceptronYSeen;
- double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen);
+ private double normalizeActualClassValue(Instance inst) {
+ double meanY = perceptronsumY / perceptronYSeen;
+ double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen);
- double normalizedY;
- if (sdY > SD_THRESHOLD){
- normalizedY = (inst.classValue() - meanY) / sdY;
- }else{
- normalizedY = inst.classValue() - meanY;
- }
- return normalizedY;
- }
+ double normalizedY;
+ if (sdY > SD_THRESHOLD) {
+ normalizedY = (inst.classValue() - meanY) / sdY;
+ } else {
+ normalizedY = inst.classValue() - meanY;
+ }
+ return normalizedY;
+ }
- @Override
- public boolean isRandomizable() {
- return true;
- }
+ @Override
+ public boolean isRandomizable() {
+ return true;
+ }
- @Override
- public double[] getVotesForInstance(Instance inst) {
- return new double[]{this.prediction(inst)};
- }
+ @Override
+ public double[] getVotesForInstance(Instance inst) {
+ return new double[] { this.prediction(inst) };
+ }
- @Override
- protected Measurement[] getModelMeasurementsImpl() {
- return null;
- }
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ return null;
+ }
- @Override
- public void getModelDescription(StringBuilder out, int indent) {
- if(this.weightAttribute!=null){
- for(int i=0; i< this.weightAttribute.length-1; ++i)
- {
- if(this.weightAttribute[i]>=0 && i>0)
- out.append(" +" + Math.round(this.weightAttribute[i]*1000)/1000.0 + " X" + i );
- else
- out.append(" " + Math.round(this.weightAttribute[i]*1000)/1000.0 + " X" + i );
- }
- if(this.weightAttribute[this.weightAttribute.length-1]>=0 )
- out.append(" +" + Math.round(this.weightAttribute[this.weightAttribute.length-1]*1000)/1000.0);
- else
- out.append(" " + Math.round(this.weightAttribute[this.weightAttribute.length-1]*1000)/1000.0);
- }
- }
+ @Override
+ public void getModelDescription(StringBuilder out, int indent) {
+ if (this.weightAttribute != null) {
+ for (int i = 0; i < this.weightAttribute.length - 1; ++i)
+ {
+ if (this.weightAttribute[i] >= 0 && i > 0)
+ out.append(" +" + Math.round(this.weightAttribute[i] * 1000) / 1000.0 + " X" + i);
+ else
+ out.append(" " + Math.round(this.weightAttribute[i] * 1000) / 1000.0 + " X" + i);
+ }
+ if (this.weightAttribute[this.weightAttribute.length - 1] >= 0)
+ out.append(" +" + Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000) / 1000.0);
+ else
+ out.append(" " + Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000) / 1000.0);
+ }
+ }
- public void setLearningRatio(double learningRatio) {
- this.learningRatio=learningRatio;
+ public void setLearningRatio(double learningRatio) {
+ this.learningRatio = learningRatio;
- }
+ }
- public double getCurrentError()
- {
- if (nError>0)
- return accumulatedError/nError;
- else
- return Double.MAX_VALUE;
- }
-
- public static class PerceptronData implements Serializable {
- /**
+ public double getCurrentError()
+ {
+ if (nError > 0)
+ return accumulatedError / nError;
+ else
+ return Double.MAX_VALUE;
+ }
+
+ public static class PerceptronData implements Serializable {
+ /**
*
*/
- private static final long serialVersionUID = 6727623208744105082L;
-
- private boolean constantLearningRatioDecay;
- // If the model (weights) should be reset or not
- private boolean initialisePerceptron;
-
- private double nError;
- private double fadingFactor;
- private double originalLearningRatio;
- private double learningRatio;
- private double learningRateDecay;
- private double accumulatedError;
- private double perceptronsumY;
- private double squaredperceptronsumY;
+ private static final long serialVersionUID = 6727623208744105082L;
- // The Perception weights
- private double[] weightAttribute;
+ private boolean constantLearningRatioDecay;
+ // If the model (weights) should be reset or not
+ private boolean initialisePerceptron;
- // Statistics used for error calculations
- private DoubleVector perceptronattributeStatistics;
- private DoubleVector squaredperceptronattributeStatistics;
+ private double nError;
+ private double fadingFactor;
+ private double originalLearningRatio;
+ private double learningRatio;
+ private double learningRateDecay;
+ private double accumulatedError;
+ private double perceptronsumY;
+ private double squaredperceptronsumY;
- // The number of instances contributing to this model
- private int perceptronInstancesSeen;
- private int perceptronYSeen;
+ // The Perception weights
+ private double[] weightAttribute;
- public PerceptronData() {
-
- }
-
- public PerceptronData(Perceptron p) {
- this.constantLearningRatioDecay = p.constantLearningRatioDecay;
- this.initialisePerceptron = p.initialisePerceptron;
- this.nError = p.nError;
- this.fadingFactor = p.fadingFactor;
- this.originalLearningRatio = p.originalLearningRatio;
- this.learningRatio = p.learningRatio;
- this.learningRateDecay = p.learningRateDecay;
- this.accumulatedError = p.accumulatedError;
- this.perceptronsumY = p.perceptronsumY;
- this.squaredperceptronsumY = p.squaredperceptronsumY;
- this.weightAttribute = p.weightAttribute;
- this.perceptronattributeStatistics = p.perceptronattributeStatistics;
- this.squaredperceptronattributeStatistics = p.squaredperceptronattributeStatistics;
- this.perceptronInstancesSeen = p.perceptronInstancesSeen;
- this.perceptronYSeen = p.perceptronYSeen;
- }
-
- public Perceptron build() {
- return new Perceptron(this);
- }
-
- }
-
-
- public static final class PerceptronSerializer extends Serializer<Perceptron>{
+ // Statistics used for error calculations
+ private DoubleVector perceptronattributeStatistics;
+ private DoubleVector squaredperceptronattributeStatistics;
- @Override
- public void write(Kryo kryo, Output output, Perceptron p) {
- kryo.writeObjectOrNull(output, new PerceptronData(p), PerceptronData.class);
- }
+ // The number of instances contributing to this model
+ private int perceptronInstancesSeen;
+ private int perceptronYSeen;
- @Override
- public Perceptron read(Kryo kryo, Input input, Class<Perceptron> type) {
- PerceptronData perceptronData = kryo.readObjectOrNull(input, PerceptronData.class);
- return perceptronData.build();
- }
- }
+ public PerceptronData() {
+
+ }
+
+ public PerceptronData(Perceptron p) {
+ this.constantLearningRatioDecay = p.constantLearningRatioDecay;
+ this.initialisePerceptron = p.initialisePerceptron;
+ this.nError = p.nError;
+ this.fadingFactor = p.fadingFactor;
+ this.originalLearningRatio = p.originalLearningRatio;
+ this.learningRatio = p.learningRatio;
+ this.learningRateDecay = p.learningRateDecay;
+ this.accumulatedError = p.accumulatedError;
+ this.perceptronsumY = p.perceptronsumY;
+ this.squaredperceptronsumY = p.squaredperceptronsumY;
+ this.weightAttribute = p.weightAttribute;
+ this.perceptronattributeStatistics = p.perceptronattributeStatistics;
+ this.squaredperceptronattributeStatistics = p.squaredperceptronattributeStatistics;
+ this.perceptronInstancesSeen = p.perceptronInstancesSeen;
+ this.perceptronYSeen = p.perceptronYSeen;
+ }
+
+ public Perceptron build() {
+ return new Perceptron(this);
+ }
+
+ }
+
+ public static final class PerceptronSerializer extends Serializer<Perceptron> {
+
+ @Override
+ public void write(Kryo kryo, Output output, Perceptron p) {
+ kryo.writeObjectOrNull(output, new PerceptronData(p), PerceptronData.class);
+ }
+
+ @Override
+ public Perceptron read(Kryo kryo, Input input, Class<Perceptron> type) {
+ PerceptronData perceptronData = kryo.readObjectOrNull(input, PerceptronData.class);
+ return perceptronData.build();
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java
index b85cf10..0f58559 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java
@@ -28,84 +28,95 @@
import com.yahoo.labs.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate;
/**
- * The base class for "rule".
- * Represents the most basic rule with and ID and a list of features (nodeList).
+ * The base class for "rule". Represents the most basic rule with and ID and a
+ * list of features (nodeList).
*
* @author Anh Thu Vu
- *
+ *
*/
public abstract class Rule extends AbstractMOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected int ruleNumberID;
-
- protected List<RuleSplitNode> nodeList;
-
- /*
- * Constructor
- */
- public Rule() {
- this.nodeList = new LinkedList<RuleSplitNode>();
- }
-
- /*
- * Rule ID
- */
- public int getRuleNumberID() {
- return ruleNumberID;
- }
+ protected int ruleNumberID;
- public void setRuleNumberID(int ruleNumberID) {
- this.ruleNumberID = ruleNumberID;
- }
-
- /*
- * RuleSplitNode list
- */
- public List<RuleSplitNode> getNodeList() {
- return nodeList;
- }
+ protected List<RuleSplitNode> nodeList;
- public void setNodeList(List<RuleSplitNode> nodeList) {
- this.nodeList = nodeList;
- }
-
- /*
- * Covering
- */
- public boolean isCovering(Instance inst) {
- boolean isCovering = true;
- for (RuleSplitNode node : nodeList) {
- if (node.evaluate(inst) == false) {
- isCovering = false;
- break;
- }
- }
- return isCovering;
- }
-
- /*
- * Add RuleSplitNode
- */
- public boolean nodeListAdd(RuleSplitNode ruleSplitNode) {
- //Check that the node is not already in the list
- boolean isIncludedInNodeList = false;
- boolean isUpdated=false;
- for (RuleSplitNode node : nodeList) {
- NumericAttributeBinaryRulePredicate nodeTest = (NumericAttributeBinaryRulePredicate) node.getSplitTest();
- NumericAttributeBinaryRulePredicate ruleSplitNodeTest = (NumericAttributeBinaryRulePredicate) ruleSplitNode.getSplitTest();
- if (nodeTest.isUsingSameAttribute(ruleSplitNodeTest)) {
- isIncludedInNodeList = true;
- if (nodeTest.isIncludedInRuleNode(ruleSplitNodeTest) == true) { //remove this line to keep the most recent attribute value
- //replace the value
- nodeTest.setAttributeValue(ruleSplitNodeTest);
- isUpdated=true; //if is updated (i.e. an expansion happened) a new learning node should be created
- }
- }
- }
- if (isIncludedInNodeList == false) {
- this.nodeList.add(ruleSplitNode);
- }
- return (!isIncludedInNodeList || isUpdated);
- }
+ /*
+ * Constructor
+ */
+ public Rule() {
+ this.nodeList = new LinkedList<RuleSplitNode>();
+ }
+
+ /*
+ * Rule ID
+ */
+ public int getRuleNumberID() {
+ return ruleNumberID;
+ }
+
+ public void setRuleNumberID(int ruleNumberID) {
+ this.ruleNumberID = ruleNumberID;
+ }
+
+ /*
+ * RuleSplitNode list
+ */
+ public List<RuleSplitNode> getNodeList() {
+ return nodeList;
+ }
+
+ public void setNodeList(List<RuleSplitNode> nodeList) {
+ this.nodeList = nodeList;
+ }
+
+ /*
+ * Covering
+ */
+ public boolean isCovering(Instance inst) {
+ boolean isCovering = true;
+ for (RuleSplitNode node : nodeList) {
+ if (node.evaluate(inst) == false) {
+ isCovering = false;
+ break;
+ }
+ }
+ return isCovering;
+ }
+
+ /*
+ * Add RuleSplitNode
+ */
+ public boolean nodeListAdd(RuleSplitNode ruleSplitNode) {
+ // Check that the node is not already in the list
+ boolean isIncludedInNodeList = false;
+ boolean isUpdated = false;
+ for (RuleSplitNode node : nodeList) {
+ NumericAttributeBinaryRulePredicate nodeTest = (NumericAttributeBinaryRulePredicate) node.getSplitTest();
+ NumericAttributeBinaryRulePredicate ruleSplitNodeTest = (NumericAttributeBinaryRulePredicate) ruleSplitNode
+ .getSplitTest();
+ if (nodeTest.isUsingSameAttribute(ruleSplitNodeTest)) {
+ isIncludedInNodeList = true;
+ if (nodeTest.isIncludedInRuleNode(ruleSplitNodeTest) == true) { // remove
+ // this
+ // line
+ // to
+ // keep
+ // the
+ // most
+ // recent
+ // attribute
+ // value
+ // replace the value
+ nodeTest.setAttributeValue(ruleSplitNodeTest);
+ isUpdated = true; // if is updated (i.e. an expansion happened) a new
+ // learning node should be created
+ }
+ }
+ }
+ if (isIncludedInNodeList == false) {
+ this.nodeList.add(ruleSplitNode);
+ }
+ return (!isIncludedInNodeList || isUpdated);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java
index f52ac32..513340e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java
@@ -21,14 +21,14 @@
*/
/**
- * Interface for Rule's LearningNode that updates both statistics
- * for expanding rule and computing predictions.
+ * Interface for Rule's LearningNode that updates both statistics for expanding
+ * rule and computing predictions.
*
* @author Anh Thu Vu
- *
+ *
*/
public interface RuleActiveLearningNode extends RulePassiveLearningNode {
- public boolean tryToExpand(double splitConfidence, double tieThreshold);
-
+ public boolean tryToExpand(double splitConfidence, double tieThreshold);
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java
index 05079ed..48daf7a 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java
@@ -37,282 +37,297 @@
import com.yahoo.labs.samoa.moa.classifiers.rules.driftdetection.PageHinkleyTest;
/**
- * LearningNode for regression rule that updates both statistics for
- * expanding rule and computing predictions.
+ * LearningNode for regression rule that updates both statistics for expanding
+ * rule and computing predictions.
*
* @author Anh Thu Vu
- *
+ *
*/
public class RuleActiveRegressionNode extends RuleRegressionNode implements RuleActiveLearningNode {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 519854943188168546L;
+ private static final long serialVersionUID = 519854943188168546L;
- protected int splitIndex = 0;
-
- protected PageHinkleyTest pageHinckleyTest;
- protected boolean changeDetection;
-
- protected double[] statisticsNewRuleActiveLearningNode = null;
- protected double[] statisticsBranchSplit = null;
- protected double[] statisticsOtherBranchSplit;
-
- protected AttributeSplitSuggestion bestSuggestion = null;
-
- protected AutoExpandVector<AttributeClassObserver> attributeObservers = new AutoExpandVector<>();
- private FIMTDDNumericAttributeClassLimitObserver numericObserver;
-
- /*
- * Simple setters & getters
- */
- public int getSplitIndex() {
- return splitIndex;
- }
+ protected int splitIndex = 0;
- public void setSplitIndex(int splitIndex) {
- this.splitIndex = splitIndex;
- }
-
- public double[] getStatisticsOtherBranchSplit() {
- return statisticsOtherBranchSplit;
- }
+ protected PageHinkleyTest pageHinckleyTest;
+ protected boolean changeDetection;
- public void setStatisticsOtherBranchSplit(double[] statisticsOtherBranchSplit) {
- this.statisticsOtherBranchSplit = statisticsOtherBranchSplit;
- }
+ protected double[] statisticsNewRuleActiveLearningNode = null;
+ protected double[] statisticsBranchSplit = null;
+ protected double[] statisticsOtherBranchSplit;
- public double[] getStatisticsBranchSplit() {
- return statisticsBranchSplit;
- }
+ protected AttributeSplitSuggestion bestSuggestion = null;
- public void setStatisticsBranchSplit(double[] statisticsBranchSplit) {
- this.statisticsBranchSplit = statisticsBranchSplit;
- }
+ protected AutoExpandVector<AttributeClassObserver> attributeObservers = new AutoExpandVector<>();
+ private FIMTDDNumericAttributeClassLimitObserver numericObserver;
- public double[] getStatisticsNewRuleActiveLearningNode() {
- return statisticsNewRuleActiveLearningNode;
- }
+ /*
+ * Simple setters & getters
+ */
+ public int getSplitIndex() {
+ return splitIndex;
+ }
- public void setStatisticsNewRuleActiveLearningNode(
- double[] statisticsNewRuleActiveLearningNode) {
- this.statisticsNewRuleActiveLearningNode = statisticsNewRuleActiveLearningNode;
- }
-
- public AttributeSplitSuggestion getBestSuggestion() {
- return bestSuggestion;
- }
-
- public void setBestSuggestion(AttributeSplitSuggestion bestSuggestion) {
- this.bestSuggestion = bestSuggestion;
- }
-
- /*
- * Constructor with builder
- */
- public RuleActiveRegressionNode() {
- super();
- }
- public RuleActiveRegressionNode(ActiveRule.Builder builder) {
- super(builder.statistics);
- this.changeDetection = builder.changeDetection;
- if (!builder.changeDetection) {
- this.pageHinckleyTest = new PageHinkleyFading(builder.threshold, builder.alpha);
- }
- this.predictionFunction = builder.predictionFunction;
- this.learningRatio = builder.learningRatio;
- this.ruleNumberID = builder.id;
- this.numericObserver = builder.numericObserver;
-
- this.perceptron = new Perceptron();
- this.perceptron.prepareForUse();
- this.perceptron.originalLearningRatio = builder.learningRatio;
- this.perceptron.constantLearningRatioDecay = builder.constantLearningRatioDecay;
+ public void setSplitIndex(int splitIndex) {
+ this.splitIndex = splitIndex;
+ }
+ public double[] getStatisticsOtherBranchSplit() {
+ return statisticsOtherBranchSplit;
+ }
- if(this.predictionFunction!=1)
- {
- this.targetMean = new TargetMean();
- if (builder.statistics[0]>0)
- this.targetMean.reset(builder.statistics[1]/builder.statistics[0],(long)builder.statistics[0]);
- }
- this.predictionFunction = builder.predictionFunction;
- if (builder.statistics!=null)
- this.nodeStatistics=new DoubleVector(builder.statistics);
- }
-
- /*
- * Update with input instance
- */
- public boolean updatePageHinckleyTest(double error) {
- boolean changeDetected = false;
- if (!this.changeDetection) {
- changeDetected = pageHinckleyTest.update(error);
- }
- return changeDetected;
- }
-
- public boolean updateChangeDetection(double error) {
- return !changeDetection && pageHinckleyTest.update(error);
- }
-
- @Override
- public void updateStatistics(Instance inst) {
- // Update the statistics for this node
- // number of instances passing through the node
- nodeStatistics.addToValue(0, 1);
- // sum of y values
- nodeStatistics.addToValue(1, inst.classValue());
- // sum of squared y values
- nodeStatistics.addToValue(2, inst.classValue()*inst.classValue());
-
- for (int i = 0; i < inst.numAttributes() - 1; i++) {
- int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
+ public void setStatisticsOtherBranchSplit(double[] statisticsOtherBranchSplit) {
+ this.statisticsOtherBranchSplit = statisticsOtherBranchSplit;
+ }
- AttributeClassObserver obs = this.attributeObservers.get(i);
- if (obs == null) {
- // At this stage all nominal attributes are ignored
- if (inst.attribute(instAttIndex).isNumeric()) //instAttIndex
- {
- obs = newNumericClassObserver();
- this.attributeObservers.set(i, obs);
- }
- }
- if (obs != null) {
- ((FIMTDDNumericAttributeClassObserver) obs).observeAttributeClass(inst.value(instAttIndex), inst.classValue(), inst.weight());
- }
- }
-
- this.perceptron.trainOnInstance(inst);
- if (this.predictionFunction != 1) { //Train target mean if prediction function is not Perceptron
- this.targetMean.trainOnInstance(inst);
- }
- }
-
- protected AttributeClassObserver newNumericClassObserver() {
- //return new FIMTDDNumericAttributeClassObserver();
- //return new FIMTDDNumericAttributeClassLimitObserver();
- //return (AttributeClassObserver)((AttributeClassObserver)this.numericObserverOption.getPreMaterializedObject()).copy();
- FIMTDDNumericAttributeClassLimitObserver newObserver = new FIMTDDNumericAttributeClassLimitObserver();
- newObserver.setMaxNodes(numericObserver.getMaxNodes());
- return newObserver;
+ public double[] getStatisticsBranchSplit() {
+ return statisticsBranchSplit;
+ }
+
+ public void setStatisticsBranchSplit(double[] statisticsBranchSplit) {
+ this.statisticsBranchSplit = statisticsBranchSplit;
+ }
+
+ public double[] getStatisticsNewRuleActiveLearningNode() {
+ return statisticsNewRuleActiveLearningNode;
+ }
+
+ public void setStatisticsNewRuleActiveLearningNode(
+ double[] statisticsNewRuleActiveLearningNode) {
+ this.statisticsNewRuleActiveLearningNode = statisticsNewRuleActiveLearningNode;
+ }
+
+ public AttributeSplitSuggestion getBestSuggestion() {
+ return bestSuggestion;
+ }
+
+ public void setBestSuggestion(AttributeSplitSuggestion bestSuggestion) {
+ this.bestSuggestion = bestSuggestion;
+ }
+
+ /*
+ * Constructor with builder
+ */
+ public RuleActiveRegressionNode() {
+ super();
+ }
+
+ public RuleActiveRegressionNode(ActiveRule.Builder builder) {
+ super(builder.statistics);
+ this.changeDetection = builder.changeDetection;
+ if (!builder.changeDetection) {
+ this.pageHinckleyTest = new PageHinkleyFading(builder.threshold, builder.alpha);
}
-
- /*
- * Init after being split from oldLearningNode
- */
- public void initialize(RuleRegressionNode oldLearningNode) {
- if(oldLearningNode.perceptron!=null)
- {
- this.perceptron=new Perceptron(oldLearningNode.perceptron);
- this.perceptron.resetError();
- this.perceptron.setLearningRatio(oldLearningNode.learningRatio);
- }
+ this.predictionFunction = builder.predictionFunction;
+ this.learningRatio = builder.learningRatio;
+ this.ruleNumberID = builder.id;
+ this.numericObserver = builder.numericObserver;
- if(oldLearningNode.targetMean!=null)
- {
- this.targetMean= new TargetMean(oldLearningNode.targetMean);
- this.targetMean.resetError();
- }
- //reset statistics
- this.nodeStatistics.setValue(0, 0);
- this.nodeStatistics.setValue(1, 0);
- this.nodeStatistics.setValue(2, 0);
- }
+ this.perceptron = new Perceptron();
+ this.perceptron.prepareForUse();
+ this.perceptron.originalLearningRatio = builder.learningRatio;
+ this.perceptron.constantLearningRatioDecay = builder.constantLearningRatioDecay;
- /*
- * Expand
- */
- @Override
- public boolean tryToExpand(double splitConfidence, double tieThreshold) {
+ if (this.predictionFunction != 1)
+ {
+ this.targetMean = new TargetMean();
+ if (builder.statistics[0] > 0)
+ this.targetMean.reset(builder.statistics[1] / builder.statistics[0], (long) builder.statistics[0]);
+ }
+ this.predictionFunction = builder.predictionFunction;
+ if (builder.statistics != null)
+ this.nodeStatistics = new DoubleVector(builder.statistics);
+ }
- // splitConfidence. Hoeffding Bound test parameter.
- // tieThreshold. Hoeffding Bound test parameter.
- SplitCriterion splitCriterion = new SDRSplitCriterionAMRules();
- //SplitCriterion splitCriterion = new SDRSplitCriterionAMRulesNode();//JD for assessing only best branch
+ /*
+ * Update with input instance
+ */
+ public boolean updatePageHinckleyTest(double error) {
+ boolean changeDetected = false;
+ if (!this.changeDetection) {
+ changeDetected = pageHinckleyTest.update(error);
+ }
+ return changeDetected;
+ }
- // Using this criterion, find the best split per attribute and rank the results
- AttributeSplitSuggestion[] bestSplitSuggestions = this.getBestSplitSuggestions(splitCriterion);
- Arrays.sort(bestSplitSuggestions);
- // Declare a variable to determine if any of the splits should be performed
- boolean shouldSplit = false;
+ public boolean updateChangeDetection(double error) {
+ return !changeDetection && pageHinckleyTest.update(error);
+ }
- // If only one split was returned, use it
- if (bestSplitSuggestions.length < 2) {
- shouldSplit = ((bestSplitSuggestions.length > 0) && (bestSplitSuggestions[0].merit > 0));
- bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
- } // Otherwise, consider which of the splits proposed may be worth trying
- else {
- // Determine the hoeffding bound value, used to select how many instances should be used to make a test decision
- // to feel reasonably confident that the test chosen by this sample is the same as what would be chosen using infinite examples
- double hoeffdingBound = computeHoeffdingBound(1, splitConfidence, getInstancesSeen());
- // Determine the top two ranked splitting suggestions
- bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
- AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2];
-
- // If the upper bound of the sample mean for the ratio of SDR(best suggestion) to SDR(second best suggestion),
- // as determined using the hoeffding bound, is less than 1, then the true mean is also less than 1, and thus at this
- // particular moment of observation the bestSuggestion is indeed the best split option with confidence 1-delta, and
- // splitting should occur.
- // Alternatively, if two or more splits are very similar or identical in terms of their splits, then a threshold limit
- // (default 0.05) is applied to the hoeffding bound; if the hoeffding bound is smaller than this limit then the two
- // competing attributes are equally good, and the split will be made on the one with the higher SDR value.
+ @Override
+ public void updateStatistics(Instance inst) {
+ // Update the statistics for this node
+ // number of instances passing through the node
+ nodeStatistics.addToValue(0, 1);
+ // sum of y values
+ nodeStatistics.addToValue(1, inst.classValue());
+ // sum of squared y values
+ nodeStatistics.addToValue(2, inst.classValue() * inst.classValue());
- if (bestSuggestion.merit > 0) {
- if ((((secondBestSuggestion.merit / bestSuggestion.merit) + hoeffdingBound) < 1)
- || (hoeffdingBound < tieThreshold)) {
- shouldSplit = true;
- }
- }
- }
+ for (int i = 0; i < inst.numAttributes() - 1; i++) {
+ int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
- if (shouldSplit) {
- AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
- double minValue = Double.MAX_VALUE;
- double[] branchMerits = SDRSplitCriterionAMRules.computeBranchSplitMerits(bestSuggestion.resultingClassDistributions);
+ AttributeClassObserver obs = this.attributeObservers.get(i);
+ if (obs == null) {
+ // At this stage all nominal attributes are ignored
+ if (inst.attribute(instAttIndex).isNumeric()) // instAttIndex
+ {
+ obs = newNumericClassObserver();
+ this.attributeObservers.set(i, obs);
+ }
+ }
+ if (obs != null) {
+ ((FIMTDDNumericAttributeClassObserver) obs).observeAttributeClass(inst.value(instAttIndex), inst.classValue(),
+ inst.weight());
+ }
+ }
- for (int i = 0; i < bestSuggestion.numSplits(); i++) {
- double value = branchMerits[i];
- if (value < minValue) {
- minValue = value;
- splitIndex = i;
- statisticsNewRuleActiveLearningNode = bestSuggestion.resultingClassDistributionFromSplit(i);
- }
- }
- statisticsBranchSplit = splitDecision.resultingClassDistributionFromSplit(splitIndex);
- statisticsOtherBranchSplit = bestSuggestion.resultingClassDistributionFromSplit(splitIndex == 0 ? 1 : 0);
+ this.perceptron.trainOnInstance(inst);
+ if (this.predictionFunction != 1) { // Train target mean if prediction
+ // function is not Perceptron
+ this.targetMean.trainOnInstance(inst);
+ }
+ }
- }
- return shouldSplit;
- }
-
- public AutoExpandVector<AttributeClassObserver> getAttributeObservers() {
- return this.attributeObservers;
- }
-
- public AttributeSplitSuggestion[] getBestSplitSuggestions(SplitCriterion criterion) {
+ protected AttributeClassObserver newNumericClassObserver() {
+ // return new FIMTDDNumericAttributeClassObserver();
+ // return new FIMTDDNumericAttributeClassLimitObserver();
+ // return
+ // (AttributeClassObserver)((AttributeClassObserver)this.numericObserverOption.getPreMaterializedObject()).copy();
+ FIMTDDNumericAttributeClassLimitObserver newObserver = new FIMTDDNumericAttributeClassLimitObserver();
+ newObserver.setMaxNodes(numericObserver.getMaxNodes());
+ return newObserver;
+ }
- List<AttributeSplitSuggestion> bestSuggestions = new LinkedList<AttributeSplitSuggestion>();
+ /*
+ * Init after being split from oldLearningNode
+ */
+ public void initialize(RuleRegressionNode oldLearningNode) {
+ if (oldLearningNode.perceptron != null)
+ {
+ this.perceptron = new Perceptron(oldLearningNode.perceptron);
+ this.perceptron.resetError();
+ this.perceptron.setLearningRatio(oldLearningNode.learningRatio);
+ }
- // Set the nodeStatistics up as the preSplitDistribution, rather than the observedClassDistribution
- double[] nodeSplitDist = this.nodeStatistics.getArrayCopy();
- for (int i = 0; i < this.attributeObservers.size(); i++) {
- AttributeClassObserver obs = this.attributeObservers.get(i);
- if (obs != null) {
+ if (oldLearningNode.targetMean != null)
+ {
+ this.targetMean = new TargetMean(oldLearningNode.targetMean);
+ this.targetMean.resetError();
+ }
+ // reset statistics
+ this.nodeStatistics.setValue(0, 0);
+ this.nodeStatistics.setValue(1, 0);
+ this.nodeStatistics.setValue(2, 0);
+ }
- // AT THIS STAGE NON-NUMERIC ATTRIBUTES ARE IGNORED
- AttributeSplitSuggestion bestSuggestion = null;
- if (obs instanceof FIMTDDNumericAttributeClassObserver) {
- bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, nodeSplitDist, i, true);
- }
+ /*
+ * Expand
+ */
+ @Override
+ public boolean tryToExpand(double splitConfidence, double tieThreshold) {
- if (bestSuggestion != null) {
- bestSuggestions.add(bestSuggestion);
- }
- }
- }
- return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
- }
-
+ // splitConfidence. Hoeffding Bound test parameter.
+ // tieThreshold. Hoeffding Bound test parameter.
+ SplitCriterion splitCriterion = new SDRSplitCriterionAMRules();
+ // SplitCriterion splitCriterion = new SDRSplitCriterionAMRulesNode();//JD
+ // for assessing only best branch
+
+ // Using this criterion, find the best split per attribute and rank the
+ // results
+ AttributeSplitSuggestion[] bestSplitSuggestions = this.getBestSplitSuggestions(splitCriterion);
+ Arrays.sort(bestSplitSuggestions);
+ // Declare a variable to determine if any of the splits should be performed
+ boolean shouldSplit = false;
+
+ // If only one split was returned, use it
+ if (bestSplitSuggestions.length < 2) {
+ shouldSplit = ((bestSplitSuggestions.length > 0) && (bestSplitSuggestions[0].merit > 0));
+ bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+ } // Otherwise, consider which of the splits proposed may be worth trying
+ else {
+ // Determine the hoeffding bound value, used to select how many instances
+ // should be used to make a test decision
+ // to feel reasonably confident that the test chosen by this sample is the
+ // same as what would be chosen using infinite examples
+ double hoeffdingBound = computeHoeffdingBound(1, splitConfidence, getInstancesSeen());
+ // Determine the top two ranked splitting suggestions
+ bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+ AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2];
+
+ // If the upper bound of the sample mean for the ratio of SDR(best
+ // suggestion) to SDR(second best suggestion),
+ // as determined using the hoeffding bound, is less than 1, then the true
+ // mean is also less than 1, and thus at this
+ // particular moment of observation the bestSuggestion is indeed the best
+ // split option with confidence 1-delta, and
+ // splitting should occur.
+ // Alternatively, if two or more splits are very similar or identical in
+ // terms of their splits, then a threshold limit
+ // (default 0.05) is applied to the hoeffding bound; if the hoeffding
+ // bound is smaller than this limit then the two
+ // competing attributes are equally good, and the split will be made on
+ // the one with the higher SDR value.
+
+ if (bestSuggestion.merit > 0) {
+ if ((((secondBestSuggestion.merit / bestSuggestion.merit) + hoeffdingBound) < 1)
+ || (hoeffdingBound < tieThreshold)) {
+ shouldSplit = true;
+ }
+ }
+ }
+
+ if (shouldSplit) {
+ AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1];
+ double minValue = Double.MAX_VALUE;
+ double[] branchMerits = SDRSplitCriterionAMRules
+ .computeBranchSplitMerits(bestSuggestion.resultingClassDistributions);
+
+ for (int i = 0; i < bestSuggestion.numSplits(); i++) {
+ double value = branchMerits[i];
+ if (value < minValue) {
+ minValue = value;
+ splitIndex = i;
+ statisticsNewRuleActiveLearningNode = bestSuggestion.resultingClassDistributionFromSplit(i);
+ }
+ }
+ statisticsBranchSplit = splitDecision.resultingClassDistributionFromSplit(splitIndex);
+ statisticsOtherBranchSplit = bestSuggestion.resultingClassDistributionFromSplit(splitIndex == 0 ? 1 : 0);
+
+ }
+ return shouldSplit;
+ }
+
+ public AutoExpandVector<AttributeClassObserver> getAttributeObservers() {
+ return this.attributeObservers;
+ }
+
+ public AttributeSplitSuggestion[] getBestSplitSuggestions(SplitCriterion criterion) {
+
+ List<AttributeSplitSuggestion> bestSuggestions = new LinkedList<AttributeSplitSuggestion>();
+
+ // Set the nodeStatistics up as the preSplitDistribution, rather than the
+ // observedClassDistribution
+ double[] nodeSplitDist = this.nodeStatistics.getArrayCopy();
+ for (int i = 0; i < this.attributeObservers.size(); i++) {
+ AttributeClassObserver obs = this.attributeObservers.get(i);
+ if (obs != null) {
+
+ // AT THIS STAGE NON-NUMERIC ATTRIBUTES ARE IGNORED
+ AttributeSplitSuggestion bestSuggestion = null;
+ if (obs instanceof FIMTDDNumericAttributeClassObserver) {
+ bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, nodeSplitDist, i, true);
+ }
+
+ if (bestSuggestion != null) {
+ bestSuggestions.add(bestSuggestion);
+ }
+ }
+ }
+ return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]);
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java
index 4934225..5bc47f5 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java
@@ -21,13 +21,12 @@
*/
/**
- * Interface for Rule's LearningNode that does not update
- * statistics for expanding rule. It only updates statistics for
- * computing predictions.
+ * Interface for Rule's LearningNode that does not update statistics for
+ * expanding rule. It only updates statistics for computing predictions.
*
* @author Anh Thu Vu
- *
+ *
*/
-public interface RulePassiveLearningNode {
+public interface RulePassiveLearningNode {
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java
index 674e482..086c9e1 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java
@@ -24,53 +24,53 @@
import com.yahoo.labs.samoa.moa.core.DoubleVector;
/**
- * LearningNode for regression rule that does not update
- * statistics for expanding rule. It only updates statistics for
- * computing predictions.
+ * LearningNode for regression rule that does not update statistics for
+ * expanding rule. It only updates statistics for computing predictions.
*
* @author Anh Thu Vu
- *
+ *
*/
public class RulePassiveRegressionNode extends RuleRegressionNode implements RulePassiveLearningNode {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 3720878438856489690L;
-
- public RulePassiveRegressionNode (double[] statistics) {
- super(statistics);
- }
-
- public RulePassiveRegressionNode() {
- super();
- }
-
- public RulePassiveRegressionNode(RuleRegressionNode activeLearningNode) {
- this.predictionFunction = activeLearningNode.predictionFunction;
- this.ruleNumberID = activeLearningNode.ruleNumberID;
- this.nodeStatistics = new DoubleVector(activeLearningNode.nodeStatistics);
- this.learningRatio = activeLearningNode.learningRatio;
- this.perceptron = new Perceptron(activeLearningNode.perceptron, true);
- this.targetMean = new TargetMean(activeLearningNode.targetMean);
- }
-
- /*
- * Update with input instance
- */
- @Override
- public void updateStatistics(Instance inst) {
- // Update the statistics for this node
- // number of instances passing through the node
- nodeStatistics.addToValue(0, 1);
- // sum of y values
- nodeStatistics.addToValue(1, inst.classValue());
- // sum of squared y values
- nodeStatistics.addToValue(2, inst.classValue()*inst.classValue());
-
- this.perceptron.trainOnInstance(inst);
- if (this.predictionFunction != 1) { //Train target mean if prediction function is not Perceptron
- this.targetMean.trainOnInstance(inst);
- }
- }
+ private static final long serialVersionUID = 3720878438856489690L;
+
+ public RulePassiveRegressionNode(double[] statistics) {
+ super(statistics);
+ }
+
+ public RulePassiveRegressionNode() {
+ super();
+ }
+
+ public RulePassiveRegressionNode(RuleRegressionNode activeLearningNode) {
+ this.predictionFunction = activeLearningNode.predictionFunction;
+ this.ruleNumberID = activeLearningNode.ruleNumberID;
+ this.nodeStatistics = new DoubleVector(activeLearningNode.nodeStatistics);
+ this.learningRatio = activeLearningNode.learningRatio;
+ this.perceptron = new Perceptron(activeLearningNode.perceptron, true);
+ this.targetMean = new TargetMean(activeLearningNode.targetMean);
+ }
+
+ /*
+ * Update with input instance
+ */
+ @Override
+ public void updateStatistics(Instance inst) {
+ // Update the statistics for this node
+ // number of instances passing through the node
+ nodeStatistics.addToValue(0, 1);
+ // sum of y values
+ nodeStatistics.addToValue(1, inst.classValue());
+ // sum of squared y values
+ nodeStatistics.addToValue(2, inst.classValue() * inst.classValue());
+
+ this.perceptron.trainOnInstance(inst);
+ if (this.predictionFunction != 1) { // Train target mean if prediction
+ // function is not Perceptron
+ this.targetMean.trainOnInstance(inst);
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java
index 45f5719..068957d 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java
@@ -29,264 +29,268 @@
* The base class for LearningNode for regression rule.
*
* @author Anh Thu Vu
- *
+ *
*/
public abstract class RuleRegressionNode implements Serializable {
-
- private static final long serialVersionUID = 9129659494380381126L;
-
- protected int predictionFunction;
- protected int ruleNumberID;
- // The statistics for this node:
- // Number of instances that have reached it
- // Sum of y values
- // Sum of squared y values
- protected DoubleVector nodeStatistics;
-
- protected Perceptron perceptron;
- protected TargetMean targetMean;
- protected double learningRatio;
-
- /*
- * Simple setters & getters
- */
- public Perceptron getPerceptron() {
- return perceptron;
- }
- public void setPerceptron(Perceptron perceptron) {
- this.perceptron = perceptron;
- }
+ private static final long serialVersionUID = 9129659494380381126L;
- public TargetMean getTargetMean() {
- return targetMean;
- }
+ protected int predictionFunction;
+ protected int ruleNumberID;
+ // The statistics for this node:
+ // Number of instances that have reached it
+ // Sum of y values
+ // Sum of squared y values
+ protected DoubleVector nodeStatistics;
- public void setTargetMean(TargetMean targetMean) {
- this.targetMean = targetMean;
- }
+ protected Perceptron perceptron;
+ protected TargetMean targetMean;
+ protected double learningRatio;
- /*
- * Create a new RuleRegressionNode
- */
- public RuleRegressionNode(double[] initialClassObservations) {
- this.nodeStatistics = new DoubleVector(initialClassObservations);
- }
+ /*
+ * Simple setters & getters
+ */
+ public Perceptron getPerceptron() {
+ return perceptron;
+ }
- public RuleRegressionNode() {
- this(new double[0]);
- }
+ public void setPerceptron(Perceptron perceptron) {
+ this.perceptron = perceptron;
+ }
- /*
- * Update statistics with input instance
- */
- public abstract void updateStatistics(Instance instance);
+ public TargetMean getTargetMean() {
+ return targetMean;
+ }
- /*
- * Predictions
- */
- public double[] getPrediction(Instance instance) {
- int predictionMode = this.getLearnerToUse(this.predictionFunction);
- return getPrediction(instance, predictionMode);
- }
-
- public double[] getSimplePrediction() {
- if( this.targetMean!=null)
- return this.targetMean.getVotesForInstance();
- else
- return new double[]{0};
- }
+ public void setTargetMean(TargetMean targetMean) {
+ this.targetMean = targetMean;
+ }
- public double[] getPrediction(Instance instance, int predictionMode) {
- double[] ret;
- if (predictionMode == 1)
- ret=this.perceptron.getVotesForInstance(instance);
- else
- ret=this.targetMean.getVotesForInstance(instance);
- return ret;
- }
-
- public double getNormalizedPrediction(Instance instance) {
- double res;
- double [] aux;
- switch (this.predictionFunction) {
- //perceptron - 1
- case 1:
- res=this.perceptron.normalizedPrediction(instance);
- break;
- //target mean - 2
- case 2:
- aux=this.targetMean.getVotesForInstance();
- res=normalize(aux[0]);
- break;
- //adaptive - 0
- case 0:
- int predictionMode = this.getLearnerToUse(0);
- if(predictionMode == 1)
- {
- res=this.perceptron.normalizedPrediction(instance);
- }
- else{
- aux=this.targetMean.getVotesForInstance(instance);
- res = normalize(aux[0]);
- }
- break;
- default:
- throw new UnsupportedOperationException("Prediction mode not in range.");
- }
- return res;
- }
+ /*
+ * Create a new RuleRegressionNode
+ */
+ public RuleRegressionNode(double[] initialClassObservations) {
+ this.nodeStatistics = new DoubleVector(initialClassObservations);
+ }
- /*
- * Get learner mode
- */
- public int getLearnerToUse(int predMode) {
- int predictionMode = predMode;
- if (predictionMode == 0) {
- double perceptronError= this.perceptron.getCurrentError();
- double meanTargetError =this.targetMean.getCurrentError();
- if (perceptronError < meanTargetError)
- predictionMode = 1; //PERCEPTRON
- else
- predictionMode = 2; //TARGET MEAN
- }
- return predictionMode;
- }
+ public RuleRegressionNode() {
+ this(new double[0]);
+ }
- /*
- * Error and change detection
- */
- public double computeError(Instance instance) {
- double normalizedPrediction = getNormalizedPrediction(instance);
- double normalizedClassValue = normalize(instance.classValue());
- return Math.abs(normalizedClassValue - normalizedPrediction);
- }
-
- public double getCurrentError() {
- double error;
- if (this.perceptron!=null){
- if (targetMean==null)
- error=perceptron.getCurrentError();
- else{
- double errorP=perceptron.getCurrentError();
- double errorTM=targetMean.getCurrentError();
- error = (errorP<errorTM) ? errorP : errorTM;
- }
- }
- else
- error=Double.MAX_VALUE;
- return error;
- }
-
- /*
- * no. of instances seen
- */
- public long getInstancesSeen() {
- if (nodeStatistics != null) {
- return (long)this.nodeStatistics.getValue(0);
- } else {
- return 0;
- }
- }
+ /*
+ * Update statistics with input instance
+ */
+ public abstract void updateStatistics(Instance instance);
- public DoubleVector getNodeStatistics(){
- return this.nodeStatistics;
- }
-
- /*
- * Anomaly detection
- */
- public boolean isAnomaly(Instance instance,
- double uniVariateAnomalyProbabilityThreshold,
- double multiVariateAnomalyProbabilityThreshold,
- int numberOfInstanceesForAnomaly) {
- //AMRUles is equipped with anomaly detection. If on, compute the anomaly value.
- long perceptronIntancesSeen=this.perceptron.getInstancesSeen();
- if ( perceptronIntancesSeen>= numberOfInstanceesForAnomaly) {
- double attribSum;
- double attribSquaredSum;
- double D = 0.0;
- double N = 0.0;
- double anomaly;
+ /*
+ * Predictions
+ */
+ public double[] getPrediction(Instance instance) {
+ int predictionMode = this.getLearnerToUse(this.predictionFunction);
+ return getPrediction(instance, predictionMode);
+ }
- for (int x = 0; x < instance.numAttributes() - 1; x++) {
- // Perceptron is initialized each rule.
- // this is a local anomaly.
- int instAttIndex = modelAttIndexToInstanceAttIndex(x, instance);
- attribSum = this.perceptron.perceptronattributeStatistics.getValue(x);
- attribSquaredSum = this.perceptron.squaredperceptronattributeStatistics.getValue(x);
- double mean = attribSum / perceptronIntancesSeen;
- double sd = computeSD(attribSquaredSum, attribSum, perceptronIntancesSeen);
- double probability = computeProbability(mean, sd, instance.value(instAttIndex));
+ public double[] getSimplePrediction() {
+ if (this.targetMean != null)
+ return this.targetMean.getVotesForInstance();
+ else
+ return new double[] { 0 };
+ }
- if (probability > 0.0) {
- D = D + Math.abs(Math.log(probability));
- if (probability < uniVariateAnomalyProbabilityThreshold) {//0.10
- N = N + Math.abs(Math.log(probability));
- }
- }
- }
+ public double[] getPrediction(Instance instance, int predictionMode) {
+ double[] ret;
+ if (predictionMode == 1)
+ ret = this.perceptron.getVotesForInstance(instance);
+ else
+ ret = this.targetMean.getVotesForInstance(instance);
+ return ret;
+ }
- anomaly = 0.0;
- if (D != 0.0) {
- anomaly = N / D;
- }
- if (anomaly >= multiVariateAnomalyProbabilityThreshold) {
- //debuganomaly(instance,
- // uniVariateAnomalyProbabilityThreshold,
- // multiVariateAnomalyProbabilityThreshold,
- // anomaly);
- return true;
- }
- }
- return false;
- }
-
- /*
- * Helpers
- */
- public static double computeProbability(double mean, double sd, double value) {
- double probability = 0.0;
-
- if (sd > 0.0) {
- double k = (Math.abs(value - mean) / sd); // One tailed variant of Chebyshev's inequality
- probability= 1.0 / (1+k*k);
- }
-
- return probability;
- }
-
- public static double computeHoeffdingBound(double range, double confidence, double n) {
- return Math.sqrt(((range * range) * Math.log(1.0 / confidence)) / (2.0 * n));
- }
-
- private double normalize(double value) {
- double meanY = this.nodeStatistics.getValue(1)/this.nodeStatistics.getValue(0);
- double sdY = computeSD(this.nodeStatistics.getValue(2), this.nodeStatistics.getValue(1), (long)this.nodeStatistics.getValue(0));
- double normalizedY = 0.0;
- if (sdY > 0.0000001) {
- normalizedY = (value - meanY) / (sdY);
- }
- return normalizedY;
- }
-
-
- public double computeSD(double squaredVal, double val, long size) {
- if (size > 1) {
- return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0));
- }
- return 0.0;
- }
-
- /**
- * Gets the index of the attribute in the instance,
- * given the index of the attribute in the learner.
- *
- * @param index the index of the attribute in the learner
- * @param inst the instance
- * @return the index in the instance
- */
- protected static int modelAttIndexToInstanceAttIndex(int index, Instance inst) {
- return index<= inst.classIndex() ? index : index + 1;
+ public double getNormalizedPrediction(Instance instance) {
+ double res;
+ double[] aux;
+ switch (this.predictionFunction) {
+ // perceptron - 1
+ case 1:
+ res = this.perceptron.normalizedPrediction(instance);
+ break;
+ // target mean - 2
+ case 2:
+ aux = this.targetMean.getVotesForInstance();
+ res = normalize(aux[0]);
+ break;
+ // adaptive - 0
+ case 0:
+ int predictionMode = this.getLearnerToUse(0);
+ if (predictionMode == 1)
+ {
+ res = this.perceptron.normalizedPrediction(instance);
+ }
+ else {
+ aux = this.targetMean.getVotesForInstance(instance);
+ res = normalize(aux[0]);
+ }
+ break;
+ default:
+ throw new UnsupportedOperationException("Prediction mode not in range.");
}
+ return res;
+ }
+
+ /*
+ * Get learner mode
+ */
+ public int getLearnerToUse(int predMode) {
+ int predictionMode = predMode;
+ if (predictionMode == 0) {
+ double perceptronError = this.perceptron.getCurrentError();
+ double meanTargetError = this.targetMean.getCurrentError();
+ if (perceptronError < meanTargetError)
+ predictionMode = 1; // PERCEPTRON
+ else
+ predictionMode = 2; // TARGET MEAN
+ }
+ return predictionMode;
+ }
+
+ /*
+ * Error and change detection
+ */
+ public double computeError(Instance instance) {
+ double normalizedPrediction = getNormalizedPrediction(instance);
+ double normalizedClassValue = normalize(instance.classValue());
+ return Math.abs(normalizedClassValue - normalizedPrediction);
+ }
+
+ public double getCurrentError() {
+ double error;
+ if (this.perceptron != null) {
+ if (targetMean == null)
+ error = perceptron.getCurrentError();
+ else {
+ double errorP = perceptron.getCurrentError();
+ double errorTM = targetMean.getCurrentError();
+ error = (errorP < errorTM) ? errorP : errorTM;
+ }
+ }
+ else
+ error = Double.MAX_VALUE;
+ return error;
+ }
+
+ /*
+ * no. of instances seen
+ */
+ public long getInstancesSeen() {
+ if (nodeStatistics != null) {
+ return (long) this.nodeStatistics.getValue(0);
+ } else {
+ return 0;
+ }
+ }
+
+ public DoubleVector getNodeStatistics() {
+ return this.nodeStatistics;
+ }
+
+ /*
+ * Anomaly detection
+ */
+ public boolean isAnomaly(Instance instance,
+ double uniVariateAnomalyProbabilityThreshold,
+ double multiVariateAnomalyProbabilityThreshold,
+ int numberOfInstanceesForAnomaly) {
+      // AMRules is equipped with anomaly detection. If on, compute the anomaly
+ // value.
+ long perceptronIntancesSeen = this.perceptron.getInstancesSeen();
+ if (perceptronIntancesSeen >= numberOfInstanceesForAnomaly) {
+ double attribSum;
+ double attribSquaredSum;
+ double D = 0.0;
+ double N = 0.0;
+ double anomaly;
+
+ for (int x = 0; x < instance.numAttributes() - 1; x++) {
+ // Perceptron is initialized each rule.
+ // this is a local anomaly.
+ int instAttIndex = modelAttIndexToInstanceAttIndex(x, instance);
+ attribSum = this.perceptron.perceptronattributeStatistics.getValue(x);
+ attribSquaredSum = this.perceptron.squaredperceptronattributeStatistics.getValue(x);
+ double mean = attribSum / perceptronIntancesSeen;
+ double sd = computeSD(attribSquaredSum, attribSum, perceptronIntancesSeen);
+ double probability = computeProbability(mean, sd, instance.value(instAttIndex));
+
+ if (probability > 0.0) {
+ D = D + Math.abs(Math.log(probability));
+ if (probability < uniVariateAnomalyProbabilityThreshold) {// 0.10
+ N = N + Math.abs(Math.log(probability));
+ }
+ }
+ }
+
+ anomaly = 0.0;
+ if (D != 0.0) {
+ anomaly = N / D;
+ }
+ if (anomaly >= multiVariateAnomalyProbabilityThreshold) {
+ // debuganomaly(instance,
+ // uniVariateAnomalyProbabilityThreshold,
+ // multiVariateAnomalyProbabilityThreshold,
+ // anomaly);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /*
+ * Helpers
+ */
+ public static double computeProbability(double mean, double sd, double value) {
+ double probability = 0.0;
+
+ if (sd > 0.0) {
+ double k = (Math.abs(value - mean) / sd); // One tailed variant of
+ // Chebyshev's inequality
+ probability = 1.0 / (1 + k * k);
+ }
+
+ return probability;
+ }
+
+ public static double computeHoeffdingBound(double range, double confidence, double n) {
+ return Math.sqrt(((range * range) * Math.log(1.0 / confidence)) / (2.0 * n));
+ }
+
+ private double normalize(double value) {
+ double meanY = this.nodeStatistics.getValue(1) / this.nodeStatistics.getValue(0);
+ double sdY = computeSD(this.nodeStatistics.getValue(2), this.nodeStatistics.getValue(1),
+ (long) this.nodeStatistics.getValue(0));
+ double normalizedY = 0.0;
+ if (sdY > 0.0000001) {
+ normalizedY = (value - meanY) / (sdY);
+ }
+ return normalizedY;
+ }
+
+ public double computeSD(double squaredVal, double val, long size) {
+ if (size > 1) {
+ return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0));
+ }
+ return 0.0;
+ }
+
+ /**
+ * Gets the index of the attribute in the instance, given the index of the
+ * attribute in the learner.
+ *
+ * @param index
+ * the index of the attribute in the learner
+ * @param inst
+ * the instance
+ * @return the index in the instance
+ */
+ protected static int modelAttIndexToInstanceAttIndex(int index, Instance inst) {
+ return index <= inst.classIndex() ? index : index + 1;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java
index 28f4890..a89345e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java
@@ -30,37 +30,39 @@
* Represent a feature of rules (an element of ruleÅ› nodeList).
*
* @author Anh Thu Vu
- *
+ *
*/
public class RuleSplitNode extends SplitNode {
- protected double lastTargetMean;
- protected int operatorObserver;
+ protected double lastTargetMean;
+ protected int operatorObserver;
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public InstanceConditionalTest getSplitTest() {
- return this.splitTest;
- }
+ public InstanceConditionalTest getSplitTest() {
+ return this.splitTest;
+ }
- /**
- * Create a new RuleSplitNode
- */
- public RuleSplitNode() {
- this(null, new double[0]);
- }
- public RuleSplitNode(InstanceConditionalTest splitTest, double[] classObservations) {
- super(splitTest, classObservations);
- }
-
- public RuleSplitNode getACopy() {
- InstanceConditionalTest splitTest = new NumericAttributeBinaryRulePredicate((NumericAttributeBinaryRulePredicate) this.getSplitTest());
- return new RuleSplitNode(splitTest, this.getObservedClassDistribution());
- }
+ /**
+ * Create a new RuleSplitNode
+ */
+ public RuleSplitNode() {
+ this(null, new double[0]);
+ }
- public boolean evaluate(Instance instance) {
- Predicate predicate = (Predicate) this.splitTest;
- return predicate.evaluate(instance);
- }
+ public RuleSplitNode(InstanceConditionalTest splitTest, double[] classObservations) {
+ super(splitTest, classObservations);
+ }
+
+ public RuleSplitNode getACopy() {
+ InstanceConditionalTest splitTest = new NumericAttributeBinaryRulePredicate(
+ (NumericAttributeBinaryRulePredicate) this.getSplitTest());
+ return new RuleSplitNode(splitTest, this.getObservedClassDistribution());
+ }
+
+ public boolean evaluate(Instance instance) {
+ Predicate predicate = (Predicate) this.splitTest;
+ return predicate.evaluate(instance);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java
index 902acf0..da00c0d 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java
@@ -59,162 +59,164 @@
public class TargetMean extends AbstractClassifier implements Regressor {
- /**
+ /**
*
*/
- protected long n;
- protected double sum;
- protected double errorSum;
- protected double nError;
- private double fadingErrorFactor;
+ protected long n;
+ protected double sum;
+ protected double errorSum;
+ protected double nError;
+ private double fadingErrorFactor;
- private static final long serialVersionUID = 7152547322803559115L;
+ private static final long serialVersionUID = 7152547322803559115L;
- public FloatOption fadingErrorFactorOption = new FloatOption(
- "fadingErrorFactor", 'e',
- "Fading error factor for the TargetMean accumulated error", 0.99, 0, 1);
+ public FloatOption fadingErrorFactorOption = new FloatOption(
+ "fadingErrorFactor", 'e',
+ "Fading error factor for the TargetMean accumulated error", 0.99, 0, 1);
- @Override
- public boolean isRandomizable() {
- return false;
- }
+ @Override
+ public boolean isRandomizable() {
+ return false;
+ }
- @Override
- public double[] getVotesForInstance(Instance inst) {
- return getVotesForInstance();
- }
-
- public double[] getVotesForInstance() {
- double[] currentMean=new double[1];
- if (n>0)
- currentMean[0]=sum/n;
- else
- currentMean[0]=0;
- return currentMean;
- }
+ @Override
+ public double[] getVotesForInstance(Instance inst) {
+ return getVotesForInstance();
+ }
- @Override
- public void resetLearningImpl() {
- sum=0;
- n=0;
- errorSum=Double.MAX_VALUE;
- nError=0;
- }
+ public double[] getVotesForInstance() {
+ double[] currentMean = new double[1];
+ if (n > 0)
+ currentMean[0] = sum / n;
+ else
+ currentMean[0] = 0;
+ return currentMean;
+ }
- @Override
- public void trainOnInstanceImpl(Instance inst) {
- updateAccumulatedError(inst);
- ++this.n;
- this.sum+=inst.classValue();
- }
- protected void updateAccumulatedError(Instance inst){
- double mean=0;
- nError=1+fadingErrorFactor*nError;
- if(n>0)
- mean=sum/n;
- errorSum=Math.abs(inst.classValue()-mean)+fadingErrorFactor*errorSum;
- }
+ @Override
+ public void resetLearningImpl() {
+ sum = 0;
+ n = 0;
+ errorSum = Double.MAX_VALUE;
+ nError = 0;
+ }
- @Override
- protected Measurement[] getModelMeasurementsImpl() {
- return null;
- }
+ @Override
+ public void trainOnInstanceImpl(Instance inst) {
+ updateAccumulatedError(inst);
+ ++this.n;
+ this.sum += inst.classValue();
+ }
- @Override
- public void getModelDescription(StringBuilder out, int indent) {
- StringUtils.appendIndented(out, indent, "Current Mean: " + this.sum/this.n);
- StringUtils.appendNewline(out);
+ protected void updateAccumulatedError(Instance inst) {
+ double mean = 0;
+ nError = 1 + fadingErrorFactor * nError;
+ if (n > 0)
+ mean = sum / n;
+ errorSum = Math.abs(inst.classValue() - mean) + fadingErrorFactor * errorSum;
+ }
- }
- /* JD
- * Resets the learner but initializes with a starting point
- * */
- public void reset(double currentMean, long numberOfInstances) {
- this.sum=currentMean*numberOfInstances;
- this.n=numberOfInstances;
- this.resetError();
- }
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ return null;
+ }
- /* JD
- * Resets the learner but initializes with a starting point
- * */
- public double getCurrentError(){
- if(this.nError>0)
- return this.errorSum/this.nError;
- else
- return Double.MAX_VALUE;
- }
+ @Override
+ public void getModelDescription(StringBuilder out, int indent) {
+ StringUtils.appendIndented(out, indent, "Current Mean: " + this.sum / this.n);
+ StringUtils.appendNewline(out);
- public TargetMean(TargetMean t) {
- super();
- this.n = t.n;
- this.sum = t.sum;
- this.errorSum = t.errorSum;
- this.nError = t.nError;
- this.fadingErrorFactor = t.fadingErrorFactor;
- this.fadingErrorFactorOption = t.fadingErrorFactorOption;
- }
-
- public TargetMean(TargetMeanData td) {
- this();
- this.n = td.n;
- this.sum = td.sum;
- this.errorSum = td.errorSum;
- this.nError = td.nError;
- this.fadingErrorFactor = td.fadingErrorFactor;
- this.fadingErrorFactorOption.setValue(td.fadingErrorFactorOptionValue);
- }
+ }
- public TargetMean() {
- super();
- fadingErrorFactor=fadingErrorFactorOption.getValue();
- }
+ /*
+ * JD Resets the learner but initializes with a starting point
+ */
+ public void reset(double currentMean, long numberOfInstances) {
+ this.sum = currentMean * numberOfInstances;
+ this.n = numberOfInstances;
+ this.resetError();
+ }
- public void resetError() {
- this.errorSum=0;
- this.nError=0;
- }
-
- public static class TargetMeanData {
- private long n;
- private double sum;
- private double errorSum;
- private double nError;
- private double fadingErrorFactor;
- private double fadingErrorFactorOptionValue;
-
- public TargetMeanData() {
-
- }
-
- public TargetMeanData(TargetMean tm) {
- this.n = tm.n;
- this.sum = tm.sum;
- this.errorSum = tm.errorSum;
- this.nError = tm.nError;
- this.fadingErrorFactor = tm.fadingErrorFactor;
- if (tm.fadingErrorFactorOption != null)
- this.fadingErrorFactorOptionValue = tm.fadingErrorFactorOption.getValue();
- else
- this.fadingErrorFactorOptionValue = 0.99;
- }
-
- public TargetMean build() {
- return new TargetMean(this);
- }
- }
-
- public static final class TargetMeanSerializer extends Serializer<TargetMean>{
+ /*
+ * JD Returns the current fading-factored accumulated error (Double.MAX_VALUE if no error has been recorded)
+ */
+ public double getCurrentError() {
+ if (this.nError > 0)
+ return this.errorSum / this.nError;
+ else
+ return Double.MAX_VALUE;
+ }
- @Override
- public void write(Kryo kryo, Output output, TargetMean t) {
- kryo.writeObjectOrNull(output, new TargetMeanData(t), TargetMeanData.class);
- }
+ public TargetMean(TargetMean t) {
+ super();
+ this.n = t.n;
+ this.sum = t.sum;
+ this.errorSum = t.errorSum;
+ this.nError = t.nError;
+ this.fadingErrorFactor = t.fadingErrorFactor;
+ this.fadingErrorFactorOption = t.fadingErrorFactorOption;
+ }
- @Override
- public TargetMean read(Kryo kryo, Input input, Class<TargetMean> type) {
- TargetMeanData data = kryo.readObjectOrNull(input, TargetMeanData.class);
- return data.build();
- }
- }
+ public TargetMean(TargetMeanData td) {
+ this();
+ this.n = td.n;
+ this.sum = td.sum;
+ this.errorSum = td.errorSum;
+ this.nError = td.nError;
+ this.fadingErrorFactor = td.fadingErrorFactor;
+ this.fadingErrorFactorOption.setValue(td.fadingErrorFactorOptionValue);
+ }
+
+ public TargetMean() {
+ super();
+ fadingErrorFactor = fadingErrorFactorOption.getValue();
+ }
+
+ public void resetError() {
+ this.errorSum = 0;
+ this.nError = 0;
+ }
+
+ public static class TargetMeanData {
+ private long n;
+ private double sum;
+ private double errorSum;
+ private double nError;
+ private double fadingErrorFactor;
+ private double fadingErrorFactorOptionValue;
+
+ public TargetMeanData() {
+
+ }
+
+ public TargetMeanData(TargetMean tm) {
+ this.n = tm.n;
+ this.sum = tm.sum;
+ this.errorSum = tm.errorSum;
+ this.nError = tm.nError;
+ this.fadingErrorFactor = tm.fadingErrorFactor;
+ if (tm.fadingErrorFactorOption != null)
+ this.fadingErrorFactorOptionValue = tm.fadingErrorFactorOption.getValue();
+ else
+ this.fadingErrorFactorOptionValue = 0.99;
+ }
+
+ public TargetMean build() {
+ return new TargetMean(this);
+ }
+ }
+
+ public static final class TargetMeanSerializer extends Serializer<TargetMean> {
+
+ @Override
+ public void write(Kryo kryo, Output output, TargetMean t) {
+ kryo.writeObjectOrNull(output, new TargetMeanData(t), TargetMeanData.class);
+ }
+
+ @Override
+ public TargetMean read(Kryo kryo, Input input, Class<TargetMean> type) {
+ TargetMeanData data = kryo.readObjectOrNull(input, TargetMeanData.class);
+ return data.build();
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java
index 54a4006..007c535 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java
@@ -39,296 +39,300 @@
* Default Rule Learner Processor (HAMR).
*
* @author Anh Thu Vu
- *
+ *
*/
public class AMRDefaultRuleProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 23702084591044447L;
-
- private static final Logger logger =
- LoggerFactory.getLogger(AMRDefaultRuleProcessor.class);
+ private static final long serialVersionUID = 23702084591044447L;
- private int processorId;
+ private static final Logger logger =
+ LoggerFactory.getLogger(AMRDefaultRuleProcessor.class);
- // Default rule
- protected transient ActiveRule defaultRule;
- protected transient int ruleNumberID;
- protected transient double[] statistics;
+ private int processorId;
- // SAMOA Stream
- private Stream ruleStream;
- private Stream resultStream;
+ // Default rule
+ protected transient ActiveRule defaultRule;
+ protected transient int ruleNumberID;
+ protected transient double[] statistics;
- // Options
- protected int pageHinckleyThreshold;
- protected double pageHinckleyAlpha;
- protected boolean driftDetection;
- protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
- protected boolean constantLearningRatioDecay;
- protected double learningRatio;
+ // SAMOA Stream
+ private Stream ruleStream;
+ private Stream resultStream;
- protected double splitConfidence;
- protected double tieThreshold;
- protected int gracePeriod;
+ // Options
+ protected int pageHinckleyThreshold;
+ protected double pageHinckleyAlpha;
+ protected boolean driftDetection;
+ protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+ protected boolean constantLearningRatioDecay;
+ protected double learningRatio;
- protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
-
- /*
- * Constructor
- */
- public AMRDefaultRuleProcessor (Builder builder) {
- this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
- this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
- this.driftDetection = builder.driftDetection;
- this.predictionFunction = builder.predictionFunction;
- this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
- this.learningRatio = builder.learningRatio;
- this.splitConfidence = builder.splitConfidence;
- this.tieThreshold = builder.tieThreshold;
- this.gracePeriod = builder.gracePeriod;
-
- this.numericObserver = builder.numericObserver;
- }
-
- @Override
- public boolean process(ContentEvent event) {
- InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
- // predict
- if (instanceEvent.isTesting()) {
- this.predictOnInstance(instanceEvent);
- }
+ protected double splitConfidence;
+ protected double tieThreshold;
+ protected int gracePeriod;
- // train
- if (instanceEvent.isTraining()) {
- this.trainOnInstance(instanceEvent);
- }
-
- return false;
- }
-
- /*
- * Prediction
- */
- private void predictOnInstance (InstanceContentEvent instanceEvent) {
- double [] vote=defaultRule.getPrediction(instanceEvent.getInstance());
- ResultContentEvent rce = newResultContentEvent(vote, instanceEvent);
- resultStream.put(rce);
- }
+ protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
- private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){
- ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent());
- rce.setClassifierIndex(this.processorId);
- rce.setEvaluationIndex(inEvent.getEvaluationIndex());
- return rce;
- }
-
- /*
- * Training
- */
- private void trainOnInstance (InstanceContentEvent instanceEvent) {
- this.trainOnInstanceImpl(instanceEvent.getInstance());
- }
- public void trainOnInstanceImpl(Instance instance) {
- defaultRule.updateStatistics(instance);
- if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
- if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
- ActiveRule newDefaultRule=newRule(defaultRule.getRuleNumberID(),(RuleActiveRegressionNode)defaultRule.getLearningNode(),
- ((RuleActiveRegressionNode)defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); //other branch
- defaultRule.split();
- defaultRule.setRuleNumberID(++ruleNumberID);
- // send out the new rule
- sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
- defaultRule=newDefaultRule;
- }
- }
- }
-
- /*
- * Create new rules
- */
- private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
- ActiveRule r=newRule(ID);
+ /*
+ * Constructor
+ */
+ public AMRDefaultRuleProcessor(Builder builder) {
+ this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
+ this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
+ this.driftDetection = builder.driftDetection;
+ this.predictionFunction = builder.predictionFunction;
+ this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
+ this.learningRatio = builder.learningRatio;
+ this.splitConfidence = builder.splitConfidence;
+ this.tieThreshold = builder.tieThreshold;
+ this.gracePeriod = builder.gracePeriod;
- if (node!=null)
- {
- if(node.getPerceptron()!=null)
- {
- r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
- r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
- }
- if (statistics==null)
- {
- double mean;
- if(node.getNodeStatistics().getValue(0)>0){
- mean=node.getNodeStatistics().getValue(1)/node.getNodeStatistics().getValue(0);
- r.getLearningNode().getTargetMean().reset(mean, 1);
- }
- }
- }
- if (statistics!=null && ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean()!=null)
- {
- double mean;
- if(statistics[0]>0){
- mean=statistics[1]/statistics[0];
- ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean().reset(mean, (long)statistics[0]);
- }
- }
- return r;
- }
+ this.numericObserver = builder.numericObserver;
+ }
- private ActiveRule newRule(int ID) {
- ActiveRule r=new ActiveRule.Builder().
- threshold(this.pageHinckleyThreshold).
- alpha(this.pageHinckleyAlpha).
- changeDetection(this.driftDetection).
- predictionFunction(this.predictionFunction).
- statistics(new double[3]).
- learningRatio(this.learningRatio).
- numericObserver(numericObserver).
- id(ID).build();
- return r;
- }
-
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- this.statistics= new double[]{0.0,0,0};
- this.ruleNumberID=0;
- this.defaultRule = newRule(++this.ruleNumberID);
- }
+ @Override
+ public boolean process(ContentEvent event) {
+ InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
+ // predict
+ if (instanceEvent.isTesting()) {
+ this.predictOnInstance(instanceEvent);
+ }
- /*
- * Clone processor
- */
- @Override
- public Processor newProcessor(Processor p) {
- AMRDefaultRuleProcessor oldProcessor = (AMRDefaultRuleProcessor) p;
- Builder builder = new Builder(oldProcessor);
- AMRDefaultRuleProcessor newProcessor = builder.build();
- newProcessor.resultStream = oldProcessor.resultStream;
- newProcessor.ruleStream = oldProcessor.ruleStream;
- return newProcessor;
- }
-
- /*
- * Send events
- */
- private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
- RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
- this.ruleStream.put(rce);
- }
-
- /*
- * Output streams
- */
- public void setRuleStream(Stream ruleStream) {
- this.ruleStream = ruleStream;
- }
-
- public Stream getRuleStream() {
- return this.ruleStream;
- }
-
- public void setResultStream(Stream resultStream) {
- this.resultStream = resultStream;
- }
-
- public Stream getResultStream() {
- return this.resultStream;
- }
-
- /*
- * Builder
- */
- public static class Builder {
- private int pageHinckleyThreshold;
- private double pageHinckleyAlpha;
- private boolean driftDetection;
- private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
- private boolean constantLearningRatioDecay;
- private double learningRatio;
- private double splitConfidence;
- private double tieThreshold;
- private int gracePeriod;
-
- private FIMTDDNumericAttributeClassLimitObserver numericObserver;
-
- private Instances dataset;
-
- public Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- public Builder(AMRDefaultRuleProcessor processor) {
- this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
- this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
- this.driftDetection = processor.driftDetection;
- this.predictionFunction = processor.predictionFunction;
- this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
- this.learningRatio = processor.learningRatio;
- this.splitConfidence = processor.splitConfidence;
- this.tieThreshold = processor.tieThreshold;
- this.gracePeriod = processor.gracePeriod;
-
- this.numericObserver = processor.numericObserver;
- }
-
- public Builder threshold(int threshold) {
- this.pageHinckleyThreshold = threshold;
- return this;
- }
-
- public Builder alpha(double alpha) {
- this.pageHinckleyAlpha = alpha;
- return this;
- }
-
- public Builder changeDetection(boolean changeDetection) {
- this.driftDetection = changeDetection;
- return this;
- }
-
- public Builder predictionFunction(int predictionFunction) {
- this.predictionFunction = predictionFunction;
- return this;
- }
-
- public Builder constantLearningRatioDecay(boolean constantDecay) {
- this.constantLearningRatioDecay = constantDecay;
- return this;
- }
-
- public Builder learningRatio(double learningRatio) {
- this.learningRatio = learningRatio;
- return this;
- }
-
- public Builder splitConfidence(double splitConfidence) {
- this.splitConfidence = splitConfidence;
- return this;
- }
-
- public Builder tieThreshold(double tieThreshold) {
- this.tieThreshold = tieThreshold;
- return this;
- }
-
- public Builder gracePeriod(int gracePeriod) {
- this.gracePeriod = gracePeriod;
- return this;
- }
-
- public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
- this.numericObserver = numericObserver;
- return this;
- }
-
- public AMRDefaultRuleProcessor build() {
- return new AMRDefaultRuleProcessor(this);
- }
- }
-
+ // train
+ if (instanceEvent.isTraining()) {
+ this.trainOnInstance(instanceEvent);
+ }
+
+ return false;
+ }
+
+ /*
+ * Prediction
+ */
+ private void predictOnInstance(InstanceContentEvent instanceEvent) {
+ double[] vote = defaultRule.getPrediction(instanceEvent.getInstance());
+ ResultContentEvent rce = newResultContentEvent(vote, instanceEvent);
+ resultStream.put(rce);
+ }
+
+ private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+ ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
+ inEvent.getClassId(), prediction, inEvent.isLastEvent());
+ rce.setClassifierIndex(this.processorId);
+ rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+ return rce;
+ }
+
+ /*
+ * Training
+ */
+ private void trainOnInstance(InstanceContentEvent instanceEvent) {
+ this.trainOnInstanceImpl(instanceEvent.getInstance());
+ }
+
+ public void trainOnInstanceImpl(Instance instance) {
+ defaultRule.updateStatistics(instance);
+ if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
+ if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
+ ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(),
+ (RuleActiveRegressionNode) defaultRule.getLearningNode(),
+ ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other
+ // branch
+ defaultRule.split();
+ defaultRule.setRuleNumberID(++ruleNumberID);
+ // send out the new rule
+ sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
+ defaultRule = newDefaultRule;
+ }
+ }
+ }
+
+ /*
+ * Create new rules
+ */
+ private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
+ ActiveRule r = newRule(ID);
+
+ if (node != null)
+ {
+ if (node.getPerceptron() != null)
+ {
+ r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
+ r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
+ }
+ if (statistics == null)
+ {
+ double mean;
+ if (node.getNodeStatistics().getValue(0) > 0) {
+ mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
+ r.getLearningNode().getTargetMean().reset(mean, 1);
+ }
+ }
+ }
+ if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null)
+ {
+ double mean;
+ if (statistics[0] > 0) {
+ mean = statistics[1] / statistics[0];
+ ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
+ }
+ }
+ return r;
+ }
+
+ private ActiveRule newRule(int ID) {
+ ActiveRule r = new ActiveRule.Builder().
+ threshold(this.pageHinckleyThreshold).
+ alpha(this.pageHinckleyAlpha).
+ changeDetection(this.driftDetection).
+ predictionFunction(this.predictionFunction).
+ statistics(new double[3]).
+ learningRatio(this.learningRatio).
+ numericObserver(numericObserver).
+ id(ID).build();
+ return r;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ this.statistics = new double[] { 0.0, 0, 0 };
+ this.ruleNumberID = 0;
+ this.defaultRule = newRule(++this.ruleNumberID);
+ }
+
+ /*
+ * Clone processor
+ */
+ @Override
+ public Processor newProcessor(Processor p) {
+ AMRDefaultRuleProcessor oldProcessor = (AMRDefaultRuleProcessor) p;
+ Builder builder = new Builder(oldProcessor);
+ AMRDefaultRuleProcessor newProcessor = builder.build();
+ newProcessor.resultStream = oldProcessor.resultStream;
+ newProcessor.ruleStream = oldProcessor.ruleStream;
+ return newProcessor;
+ }
+
+ /*
+ * Send events
+ */
+ private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
+ RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
+ this.ruleStream.put(rce);
+ }
+
+ /*
+ * Output streams
+ */
+ public void setRuleStream(Stream ruleStream) {
+ this.ruleStream = ruleStream;
+ }
+
+ public Stream getRuleStream() {
+ return this.ruleStream;
+ }
+
+ public void setResultStream(Stream resultStream) {
+ this.resultStream = resultStream;
+ }
+
+ public Stream getResultStream() {
+ return this.resultStream;
+ }
+
+ /*
+ * Builder
+ */
+ public static class Builder {
+ private int pageHinckleyThreshold;
+ private double pageHinckleyAlpha;
+ private boolean driftDetection;
+ private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+ private boolean constantLearningRatioDecay;
+ private double learningRatio;
+ private double splitConfidence;
+ private double tieThreshold;
+ private int gracePeriod;
+
+ private FIMTDDNumericAttributeClassLimitObserver numericObserver;
+
+ private Instances dataset;
+
+ public Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ public Builder(AMRDefaultRuleProcessor processor) {
+ this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
+ this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
+ this.driftDetection = processor.driftDetection;
+ this.predictionFunction = processor.predictionFunction;
+ this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
+ this.learningRatio = processor.learningRatio;
+ this.splitConfidence = processor.splitConfidence;
+ this.tieThreshold = processor.tieThreshold;
+ this.gracePeriod = processor.gracePeriod;
+
+ this.numericObserver = processor.numericObserver;
+ }
+
+ public Builder threshold(int threshold) {
+ this.pageHinckleyThreshold = threshold;
+ return this;
+ }
+
+ public Builder alpha(double alpha) {
+ this.pageHinckleyAlpha = alpha;
+ return this;
+ }
+
+ public Builder changeDetection(boolean changeDetection) {
+ this.driftDetection = changeDetection;
+ return this;
+ }
+
+ public Builder predictionFunction(int predictionFunction) {
+ this.predictionFunction = predictionFunction;
+ return this;
+ }
+
+ public Builder constantLearningRatioDecay(boolean constantDecay) {
+ this.constantLearningRatioDecay = constantDecay;
+ return this;
+ }
+
+ public Builder learningRatio(double learningRatio) {
+ this.learningRatio = learningRatio;
+ return this;
+ }
+
+ public Builder splitConfidence(double splitConfidence) {
+ this.splitConfidence = splitConfidence;
+ return this;
+ }
+
+ public Builder tieThreshold(double tieThreshold) {
+ this.tieThreshold = tieThreshold;
+ return this;
+ }
+
+ public Builder gracePeriod(int gracePeriod) {
+ this.gracePeriod = gracePeriod;
+ return this;
+ }
+
+ public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
+ this.numericObserver = numericObserver;
+ return this;
+ }
+
+ public AMRDefaultRuleProcessor build() {
+ return new AMRDefaultRuleProcessor(this);
+ }
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java
index 8ec118d..9d51075 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java
@@ -42,218 +42,219 @@
* Learner Processor (HAMR).
*
* @author Anh Thu Vu
- *
+ *
*/
public class AMRLearnerProcessor implements Processor {
-
- /**
+
+ /**
*
*/
- private static final long serialVersionUID = -2302897295090248013L;
-
- private static final Logger logger =
- LoggerFactory.getLogger(AMRLearnerProcessor.class);
+ private static final long serialVersionUID = -2302897295090248013L;
- private int processorId;
-
- private transient List<ActiveRule> ruleSet;
-
- private Stream outputStream;
-
- private double splitConfidence;
- private double tieThreshold;
- private int gracePeriod;
-
- private boolean noAnomalyDetection;
- private double multivariateAnomalyProbabilityThreshold;
- private double univariateAnomalyprobabilityThreshold;
- private int anomalyNumInstThreshold;
-
- public AMRLearnerProcessor(Builder builder) {
- this.splitConfidence = builder.splitConfidence;
- this.tieThreshold = builder.tieThreshold;
- this.gracePeriod = builder.gracePeriod;
-
- this.noAnomalyDetection = builder.noAnomalyDetection;
- this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
- this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
- this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
- }
+ private static final Logger logger =
+ LoggerFactory.getLogger(AMRLearnerProcessor.class);
- @Override
- public boolean process(ContentEvent event) {
- if (event instanceof AssignmentContentEvent) {
- AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event;
- trainRuleOnInstance(attrContentEvent.getRuleNumberID(),attrContentEvent.getInstance());
- }
- else if (event instanceof RuleContentEvent) {
- RuleContentEvent ruleContentEvent = (RuleContentEvent) event;
- if (!ruleContentEvent.isRemoving()) {
- addRule(ruleContentEvent.getRule());
- }
- }
-
- return false;
- }
-
- /*
- * Process input instances
- */
- private void trainRuleOnInstance(int ruleID, Instance instance) {
- //System.out.println("Processor:"+this.processorId+": Rule:"+ruleID+" -> Counter="+counter);
- Iterator<ActiveRule> ruleIterator= this.ruleSet.iterator();
- while (ruleIterator.hasNext()) {
- ActiveRule rule = ruleIterator.next();
- if (rule.getRuleNumberID() == ruleID) {
- // Check (again) for coverage
- if (rule.isCovering(instance) == true) {
- double error = rule.computeError(instance); //Use adaptive mode error
- boolean changeDetected = ((RuleActiveRegressionNode)rule.getLearningNode()).updateChangeDetection(error);
- if (changeDetected == true) {
- ruleIterator.remove();
-
- this.sendRemoveRuleEvent(ruleID);
- } else {
- rule.updateStatistics(instance);
- if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
- if (rule.tryToExpand(this.splitConfidence, this.tieThreshold) ) {
- rule.split();
-
- // expanded: update Aggregator with new/updated predicate
- this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(),
- (RuleActiveRegressionNode)rule.getLearningNode());
- }
-
- }
-
- }
- }
-
- return;
- }
- }
- }
-
- private boolean isAnomaly(Instance instance, LearningRule rule) {
- //AMRUles is equipped with anomaly detection. If on, compute the anomaly value.
- boolean isAnomaly = false;
- if (this.noAnomalyDetection == false){
- if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
- isAnomaly = rule.isAnomaly(instance,
- this.univariateAnomalyprobabilityThreshold,
- this.multivariateAnomalyProbabilityThreshold,
- this.anomalyNumInstThreshold);
- }
- }
- return isAnomaly;
- }
-
- private void sendRemoveRuleEvent(int ruleID) {
- RuleContentEvent rce = new RuleContentEvent(ruleID, null, true);
- this.outputStream.put(rce);
- }
-
- private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) {
- this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode)));
- }
-
- /*
- * Process control message (regarding adding or removing rules)
- */
- private boolean addRule(ActiveRule rule) {
- this.ruleSet.add(rule);
- return true;
- }
+ private int processorId;
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- this.ruleSet = new LinkedList<ActiveRule>();
- }
+ private transient List<ActiveRule> ruleSet;
- @Override
- public Processor newProcessor(Processor p) {
- AMRLearnerProcessor oldProcessor = (AMRLearnerProcessor)p;
- AMRLearnerProcessor newProcessor =
- new AMRLearnerProcessor.Builder(oldProcessor).build();
-
- newProcessor.setOutputStream(oldProcessor.outputStream);
- return newProcessor;
- }
-
- /*
- * Builder
- */
- public static class Builder {
- private double splitConfidence;
- private double tieThreshold;
- private int gracePeriod;
-
- private boolean noAnomalyDetection;
- private double multivariateAnomalyProbabilityThreshold;
- private double univariateAnomalyprobabilityThreshold;
- private int anomalyNumInstThreshold;
-
- private Instances dataset;
-
- public Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- public Builder(AMRLearnerProcessor processor) {
- this.splitConfidence = processor.splitConfidence;
- this.tieThreshold = processor.tieThreshold;
- this.gracePeriod = processor.gracePeriod;
- }
-
- public Builder splitConfidence(double splitConfidence) {
- this.splitConfidence = splitConfidence;
- return this;
- }
-
- public Builder tieThreshold(double tieThreshold) {
- this.tieThreshold = tieThreshold;
- return this;
- }
-
- public Builder gracePeriod(int gracePeriod) {
- this.gracePeriod = gracePeriod;
- return this;
- }
-
- public Builder noAnomalyDetection(boolean noAnomalyDetection) {
- this.noAnomalyDetection = noAnomalyDetection;
- return this;
- }
-
- public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
- this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
- return this;
- }
-
- public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
- this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
- return this;
- }
-
- public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
- this.anomalyNumInstThreshold = anomalyNumInstThreshold;
- return this;
- }
-
- public AMRLearnerProcessor build() {
- return new AMRLearnerProcessor(this);
- }
- }
-
- /*
- * Output stream
- */
- public void setOutputStream(Stream stream) {
- this.outputStream = stream;
- }
-
- public Stream getOutputStream() {
- return this.outputStream;
- }
+ private Stream outputStream;
+
+ private double splitConfidence;
+ private double tieThreshold;
+ private int gracePeriod;
+
+ private boolean noAnomalyDetection;
+ private double multivariateAnomalyProbabilityThreshold;
+ private double univariateAnomalyprobabilityThreshold;
+ private int anomalyNumInstThreshold;
+
+ public AMRLearnerProcessor(Builder builder) {
+ this.splitConfidence = builder.splitConfidence;
+ this.tieThreshold = builder.tieThreshold;
+ this.gracePeriod = builder.gracePeriod;
+
+ this.noAnomalyDetection = builder.noAnomalyDetection;
+ this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
+ this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
+ this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
+ }
+
+ @Override
+ public boolean process(ContentEvent event) {
+ if (event instanceof AssignmentContentEvent) {
+ AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event;
+ trainRuleOnInstance(attrContentEvent.getRuleNumberID(), attrContentEvent.getInstance());
+ }
+ else if (event instanceof RuleContentEvent) {
+ RuleContentEvent ruleContentEvent = (RuleContentEvent) event;
+ if (!ruleContentEvent.isRemoving()) {
+ addRule(ruleContentEvent.getRule());
+ }
+ }
+
+ return false;
+ }
+
+ /*
+ * Process input instances
+ */
+ private void trainRuleOnInstance(int ruleID, Instance instance) {
+ // System.out.println("Processor:"+this.processorId+": Rule:"+ruleID+" -> Counter="+counter);
+ Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
+ while (ruleIterator.hasNext()) {
+ ActiveRule rule = ruleIterator.next();
+ if (rule.getRuleNumberID() == ruleID) {
+ // Check (again) for coverage
+ if (rule.isCovering(instance) == true) {
+ double error = rule.computeError(instance); // Use adaptive mode error
+ boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
+ if (changeDetected == true) {
+ ruleIterator.remove();
+
+ this.sendRemoveRuleEvent(ruleID);
+ } else {
+ rule.updateStatistics(instance);
+ if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
+ if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
+ rule.split();
+
+ // expanded: update Aggregator with new/updated predicate
+ this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(),
+ (RuleActiveRegressionNode) rule.getLearningNode());
+ }
+
+ }
+
+ }
+ }
+
+ return;
+ }
+ }
+ }
+
+ private boolean isAnomaly(Instance instance, LearningRule rule) {
+ // AMRules is equipped with anomaly detection. If on, compute the anomaly
+ // value.
+ boolean isAnomaly = false;
+ if (this.noAnomalyDetection == false) {
+ if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
+ isAnomaly = rule.isAnomaly(instance,
+ this.univariateAnomalyprobabilityThreshold,
+ this.multivariateAnomalyProbabilityThreshold,
+ this.anomalyNumInstThreshold);
+ }
+ }
+ return isAnomaly;
+ }
+
+ private void sendRemoveRuleEvent(int ruleID) {
+ RuleContentEvent rce = new RuleContentEvent(ruleID, null, true);
+ this.outputStream.put(rce);
+ }
+
+ private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) {
+ this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode)));
+ }
+
+ /*
+ * Process control message (regarding adding or removing rules)
+ */
+ private boolean addRule(ActiveRule rule) {
+ this.ruleSet.add(rule);
+ return true;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ this.ruleSet = new LinkedList<ActiveRule>();
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ AMRLearnerProcessor oldProcessor = (AMRLearnerProcessor) p;
+ AMRLearnerProcessor newProcessor =
+ new AMRLearnerProcessor.Builder(oldProcessor).build();
+
+ newProcessor.setOutputStream(oldProcessor.outputStream);
+ return newProcessor;
+ }
+
+ /*
+ * Builder
+ */
+ public static class Builder {
+ private double splitConfidence;
+ private double tieThreshold;
+ private int gracePeriod;
+
+ private boolean noAnomalyDetection;
+ private double multivariateAnomalyProbabilityThreshold;
+ private double univariateAnomalyprobabilityThreshold;
+ private int anomalyNumInstThreshold;
+
+ private Instances dataset;
+
+ public Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ public Builder(AMRLearnerProcessor processor) {
+ this.splitConfidence = processor.splitConfidence;
+ this.tieThreshold = processor.tieThreshold;
+ this.gracePeriod = processor.gracePeriod;
+ }
+
+ public Builder splitConfidence(double splitConfidence) {
+ this.splitConfidence = splitConfidence;
+ return this;
+ }
+
+ public Builder tieThreshold(double tieThreshold) {
+ this.tieThreshold = tieThreshold;
+ return this;
+ }
+
+ public Builder gracePeriod(int gracePeriod) {
+ this.gracePeriod = gracePeriod;
+ return this;
+ }
+
+ public Builder noAnomalyDetection(boolean noAnomalyDetection) {
+ this.noAnomalyDetection = noAnomalyDetection;
+ return this;
+ }
+
+ public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
+ this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
+ return this;
+ }
+
+ public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
+ this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
+ return this;
+ }
+
+ public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
+ this.anomalyNumInstThreshold = anomalyNumInstThreshold;
+ return this;
+ }
+
+ public AMRLearnerProcessor build() {
+ return new AMRLearnerProcessor(this);
+ }
+ }
+
+ /*
+ * Output stream
+ */
+ public void setOutputStream(Stream stream) {
+ this.outputStream = stream;
+ }
+
+ public Stream getOutputStream() {
+ return this.outputStream;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java
index 38a0be1..88bf375 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java
@@ -40,323 +40,333 @@
/**
* Model Aggregator Processor (HAMR).
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class AMRRuleSetProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -6544096255649379334L;
- private static final Logger logger = LoggerFactory.getLogger(AMRRuleSetProcessor.class);
+ private static final long serialVersionUID = -6544096255649379334L;
+ private static final Logger logger = LoggerFactory.getLogger(AMRRuleSetProcessor.class);
- private int processorId;
+ private int processorId;
- // Rules & default rule
- protected transient List<PassiveRule> ruleSet;
+ // Rules & default rule
+ protected transient List<PassiveRule> ruleSet;
- // SAMOA Stream
- private Stream statisticsStream;
- private Stream resultStream;
- private Stream defaultRuleStream;
+ // SAMOA Stream
+ private Stream statisticsStream;
+ private Stream resultStream;
+ private Stream defaultRuleStream;
- // Options
- protected boolean noAnomalyDetection;
- protected double multivariateAnomalyProbabilityThreshold;
- protected double univariateAnomalyprobabilityThreshold;
- protected int anomalyNumInstThreshold;
+ // Options
+ protected boolean noAnomalyDetection;
+ protected double multivariateAnomalyProbabilityThreshold;
+ protected double univariateAnomalyprobabilityThreshold;
+ protected int anomalyNumInstThreshold;
- protected boolean unorderedRules;
-
- protected int voteType;
-
- /*
- * Constructor
- */
- public AMRRuleSetProcessor (Builder builder) {
-
- this.noAnomalyDetection = builder.noAnomalyDetection;
- this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
- this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
- this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
- this.unorderedRules = builder.unorderedRules;
-
- this.voteType = builder.voteType;
- }
- /* (non-Javadoc)
- * @see com.yahoo.labs.samoa.core.Processor#process(com.yahoo.labs.samoa.core.ContentEvent)
- */
- @Override
- public boolean process(ContentEvent event) {
- if (event instanceof InstanceContentEvent) {
- this.processInstanceEvent((InstanceContentEvent) event);
- }
- else if (event instanceof PredicateContentEvent) {
- PredicateContentEvent pce = (PredicateContentEvent) event;
- if (pce.getRuleSplitNode() == null) {
- this.updateLearningNode(pce);
- }
- else {
- this.updateRuleSplitNode(pce);
- }
- }
- else if (event instanceof RuleContentEvent) {
- RuleContentEvent rce = (RuleContentEvent) event;
- if (rce.isRemoving()) {
- this.removeRule(rce.getRuleNumberID());
- }
- else {
- addRule(rce.getRule());
- }
- }
- return true;
- }
-
- private void processInstanceEvent(InstanceContentEvent instanceEvent) {
- Instance instance = instanceEvent.getInstance();
- boolean predictionCovered = false;
- boolean trainingCovered = false;
- boolean continuePrediction = instanceEvent.isTesting();
- boolean continueTraining = instanceEvent.isTraining();
-
- ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
- for (PassiveRule aRuleSet : this.ruleSet) {
- if (!continuePrediction && !continueTraining)
- break;
+ protected boolean unorderedRules;
- if (aRuleSet.isCovering(instance)) {
- predictionCovered = true;
+ protected int voteType;
- if (continuePrediction) {
- double[] vote = aRuleSet.getPrediction(instance);
- double error = aRuleSet.getCurrentError();
- errorWeightedVote.addVote(vote, error);
- if (!this.unorderedRules) continuePrediction = false;
- }
+ /*
+ * Constructor
+ */
+ public AMRRuleSetProcessor(Builder builder) {
- if (continueTraining) {
- if (!isAnomaly(instance, aRuleSet)) {
- trainingCovered = true;
- aRuleSet.updateStatistics(instance);
+ this.noAnomalyDetection = builder.noAnomalyDetection;
+ this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
+ this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
+ this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
+ this.unorderedRules = builder.unorderedRules;
- // Send instance to statistics PIs
- sendInstanceToRule(instance, aRuleSet.getRuleNumberID());
+ this.voteType = builder.voteType;
+ }
- if (!this.unorderedRules) continueTraining = false;
- }
- }
- }
- }
-
- if (predictionCovered) {
- // Combined prediction
- ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent);
- resultStream.put(rce);
- }
-
- boolean defaultPrediction = instanceEvent.isTesting() && !predictionCovered;
- boolean defaultTraining = instanceEvent.isTraining() && !trainingCovered;
- if (defaultPrediction || defaultTraining) {
- instanceEvent.setTesting(defaultPrediction);
- instanceEvent.setTraining(defaultTraining);
- this.defaultRuleStream.put(instanceEvent);
- }
- }
-
- private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){
- ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent());
- rce.setClassifierIndex(this.processorId);
- rce.setEvaluationIndex(inEvent.getEvaluationIndex());
- return rce;
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see com.yahoo.labs.samoa.core.Processor#process(com.yahoo.labs.samoa.core.
+ * ContentEvent)
+ */
+ @Override
+ public boolean process(ContentEvent event) {
+ if (event instanceof InstanceContentEvent) {
+ this.processInstanceEvent((InstanceContentEvent) event);
+ }
+ else if (event instanceof PredicateContentEvent) {
+ PredicateContentEvent pce = (PredicateContentEvent) event;
+ if (pce.getRuleSplitNode() == null) {
+ this.updateLearningNode(pce);
+ }
+ else {
+ this.updateRuleSplitNode(pce);
+ }
+ }
+ else if (event instanceof RuleContentEvent) {
+ RuleContentEvent rce = (RuleContentEvent) event;
+ if (rce.isRemoving()) {
+ this.removeRule(rce.getRuleNumberID());
+ }
+ else {
+ addRule(rce.getRule());
+ }
+ }
+ return true;
+ }
- public ErrorWeightedVote newErrorWeightedVote() {
- // TODO: do a reset instead of init a new object
- if (voteType == 1)
- return new UniformWeightedVote();
- return new InverseErrorWeightedVote();
- }
+ private void processInstanceEvent(InstanceContentEvent instanceEvent) {
+ Instance instance = instanceEvent.getInstance();
+ boolean predictionCovered = false;
+ boolean trainingCovered = false;
+ boolean continuePrediction = instanceEvent.isTesting();
+ boolean continueTraining = instanceEvent.isTraining();
- /**
- * Method to verify if the instance is an anomaly.
- * @param instance
- * @param rule
- * @return
- */
- private boolean isAnomaly(Instance instance, LearningRule rule) {
- //AMRUles is equipped with anomaly detection. If on, compute the anomaly value.
- boolean isAnomaly = false;
- if (!this.noAnomalyDetection){
- if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
- isAnomaly = rule.isAnomaly(instance,
- this.univariateAnomalyprobabilityThreshold,
- this.multivariateAnomalyProbabilityThreshold,
- this.anomalyNumInstThreshold);
- }
- }
- return isAnomaly;
- }
-
- /*
- * Add predicate/RuleSplitNode for a rule
- */
- private void updateRuleSplitNode(PredicateContentEvent pce) {
- int ruleID = pce.getRuleNumberID();
- for (PassiveRule rule:ruleSet) {
- if (rule.getRuleNumberID() == ruleID) {
- rule.nodeListAdd(pce.getRuleSplitNode());
- rule.setLearningNode(pce.getLearningNode());
- }
- }
- }
-
- private void updateLearningNode(PredicateContentEvent pce) {
- int ruleID = pce.getRuleNumberID();
- for (PassiveRule rule:ruleSet) {
- if (rule.getRuleNumberID() == ruleID) {
- rule.setLearningNode(pce.getLearningNode());
- }
- }
- }
-
- /*
- * Add new rule/Remove rule
- */
- private boolean addRule(ActiveRule rule) {
- this.ruleSet.add(new PassiveRule(rule));
- return true;
- }
-
- private void removeRule(int ruleID) {
- for (PassiveRule rule:ruleSet) {
- if (rule.getRuleNumberID() == ruleID) {
- ruleSet.remove(rule);
- break;
- }
- }
- }
+ ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
+ for (PassiveRule aRuleSet : this.ruleSet) {
+ if (!continuePrediction && !continueTraining)
+ break;
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- this.ruleSet = new LinkedList<PassiveRule>();
-
- }
-
- /*
- * Clone processor
- */
- @Override
- public Processor newProcessor(Processor p) {
- AMRRuleSetProcessor oldProcessor = (AMRRuleSetProcessor) p;
- Builder builder = new Builder(oldProcessor);
- AMRRuleSetProcessor newProcessor = builder.build();
- newProcessor.resultStream = oldProcessor.resultStream;
- newProcessor.statisticsStream = oldProcessor.statisticsStream;
- newProcessor.defaultRuleStream = oldProcessor.defaultRuleStream;
- return newProcessor;
- }
-
- /*
- * Send events
- */
- private void sendInstanceToRule(Instance instance, int ruleID) {
- AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance);
- this.statisticsStream.put(ace);
- }
-
- /*
- * Output streams
- */
- public void setStatisticsStream(Stream statisticsStream) {
- this.statisticsStream = statisticsStream;
- }
-
- public Stream getStatisticsStream() {
- return this.statisticsStream;
- }
-
- public void setResultStream(Stream resultStream) {
- this.resultStream = resultStream;
- }
-
- public Stream getResultStream() {
- return this.resultStream;
- }
-
- public Stream getDefaultRuleStream() {
- return this.defaultRuleStream;
- }
-
- public void setDefaultRuleStream(Stream defaultRuleStream) {
- this.defaultRuleStream = defaultRuleStream;
- }
-
- /*
- * Builder
- */
- public static class Builder {
- private boolean noAnomalyDetection;
- private double multivariateAnomalyProbabilityThreshold;
- private double univariateAnomalyprobabilityThreshold;
- private int anomalyNumInstThreshold;
-
- private boolean unorderedRules;
-
-// private FIMTDDNumericAttributeClassLimitObserver numericObserver;
- private int voteType;
-
- private Instances dataset;
-
- public Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- public Builder(AMRRuleSetProcessor processor) {
-
- this.noAnomalyDetection = processor.noAnomalyDetection;
- this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
- this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
- this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
- this.unorderedRules = processor.unorderedRules;
-
- this.voteType = processor.voteType;
- }
-
- public Builder noAnomalyDetection(boolean noAnomalyDetection) {
- this.noAnomalyDetection = noAnomalyDetection;
- return this;
- }
-
- public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
- this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
- return this;
- }
-
- public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
- this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
- return this;
- }
-
- public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
- this.anomalyNumInstThreshold = anomalyNumInstThreshold;
- return this;
- }
-
- public Builder unorderedRules(boolean unorderedRules) {
- this.unorderedRules = unorderedRules;
- return this;
- }
-
- public Builder voteType(int voteType) {
- this.voteType = voteType;
- return this;
- }
-
- public AMRRuleSetProcessor build() {
- return new AMRRuleSetProcessor(this);
- }
- }
+ if (aRuleSet.isCovering(instance)) {
+ predictionCovered = true;
+
+ if (continuePrediction) {
+ double[] vote = aRuleSet.getPrediction(instance);
+ double error = aRuleSet.getCurrentError();
+ errorWeightedVote.addVote(vote, error);
+ if (!this.unorderedRules)
+ continuePrediction = false;
+ }
+
+ if (continueTraining) {
+ if (!isAnomaly(instance, aRuleSet)) {
+ trainingCovered = true;
+ aRuleSet.updateStatistics(instance);
+
+ // Send instance to statistics PIs
+ sendInstanceToRule(instance, aRuleSet.getRuleNumberID());
+
+ if (!this.unorderedRules)
+ continueTraining = false;
+ }
+ }
+ }
+ }
+
+ if (predictionCovered) {
+ // Combined prediction
+ ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent);
+ resultStream.put(rce);
+ }
+
+ boolean defaultPrediction = instanceEvent.isTesting() && !predictionCovered;
+ boolean defaultTraining = instanceEvent.isTraining() && !trainingCovered;
+ if (defaultPrediction || defaultTraining) {
+ instanceEvent.setTesting(defaultPrediction);
+ instanceEvent.setTraining(defaultTraining);
+ this.defaultRuleStream.put(instanceEvent);
+ }
+ }
+
+ private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+ ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
+ inEvent.getClassId(), prediction, inEvent.isLastEvent());
+ rce.setClassifierIndex(this.processorId);
+ rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+ return rce;
+ }
+
+ public ErrorWeightedVote newErrorWeightedVote() {
+ // TODO: do a reset instead of init a new object
+ if (voteType == 1)
+ return new UniformWeightedVote();
+ return new InverseErrorWeightedVote();
+ }
+
+ /**
+ * Method to verify if the instance is an anomaly.
+ *
+ * @param instance
+ * @param rule
+ * @return
+ */
+ private boolean isAnomaly(Instance instance, LearningRule rule) {
+ // AMRules is equipped with anomaly detection. If on, compute the anomaly
+ // value.
+ boolean isAnomaly = false;
+ if (!this.noAnomalyDetection) {
+ if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
+ isAnomaly = rule.isAnomaly(instance,
+ this.univariateAnomalyprobabilityThreshold,
+ this.multivariateAnomalyProbabilityThreshold,
+ this.anomalyNumInstThreshold);
+ }
+ }
+ return isAnomaly;
+ }
+
+ /*
+ * Add predicate/RuleSplitNode for a rule
+ */
+ private void updateRuleSplitNode(PredicateContentEvent pce) {
+ int ruleID = pce.getRuleNumberID();
+ for (PassiveRule rule : ruleSet) {
+ if (rule.getRuleNumberID() == ruleID) {
+ rule.nodeListAdd(pce.getRuleSplitNode());
+ rule.setLearningNode(pce.getLearningNode());
+ }
+ }
+ }
+
+ private void updateLearningNode(PredicateContentEvent pce) {
+ int ruleID = pce.getRuleNumberID();
+ for (PassiveRule rule : ruleSet) {
+ if (rule.getRuleNumberID() == ruleID) {
+ rule.setLearningNode(pce.getLearningNode());
+ }
+ }
+ }
+
+ /*
+ * Add new rule/Remove rule
+ */
+ private boolean addRule(ActiveRule rule) {
+ this.ruleSet.add(new PassiveRule(rule));
+ return true;
+ }
+
+ private void removeRule(int ruleID) {
+ for (PassiveRule rule : ruleSet) {
+ if (rule.getRuleNumberID() == ruleID) {
+ ruleSet.remove(rule);
+ break;
+ }
+ }
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ this.ruleSet = new LinkedList<PassiveRule>();
+
+ }
+
+ /*
+ * Clone processor
+ */
+ @Override
+ public Processor newProcessor(Processor p) {
+ AMRRuleSetProcessor oldProcessor = (AMRRuleSetProcessor) p;
+ Builder builder = new Builder(oldProcessor);
+ AMRRuleSetProcessor newProcessor = builder.build();
+ newProcessor.resultStream = oldProcessor.resultStream;
+ newProcessor.statisticsStream = oldProcessor.statisticsStream;
+ newProcessor.defaultRuleStream = oldProcessor.defaultRuleStream;
+ return newProcessor;
+ }
+
+ /*
+ * Send events
+ */
+ private void sendInstanceToRule(Instance instance, int ruleID) {
+ AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance);
+ this.statisticsStream.put(ace);
+ }
+
+ /*
+ * Output streams
+ */
+ public void setStatisticsStream(Stream statisticsStream) {
+ this.statisticsStream = statisticsStream;
+ }
+
+ public Stream getStatisticsStream() {
+ return this.statisticsStream;
+ }
+
+ public void setResultStream(Stream resultStream) {
+ this.resultStream = resultStream;
+ }
+
+ public Stream getResultStream() {
+ return this.resultStream;
+ }
+
+ public Stream getDefaultRuleStream() {
+ return this.defaultRuleStream;
+ }
+
+ public void setDefaultRuleStream(Stream defaultRuleStream) {
+ this.defaultRuleStream = defaultRuleStream;
+ }
+
+ /*
+ * Builder
+ */
+ public static class Builder {
+ private boolean noAnomalyDetection;
+ private double multivariateAnomalyProbabilityThreshold;
+ private double univariateAnomalyprobabilityThreshold;
+ private int anomalyNumInstThreshold;
+
+ private boolean unorderedRules;
+
+ // private FIMTDDNumericAttributeClassLimitObserver numericObserver;
+ private int voteType;
+
+ private Instances dataset;
+
+ public Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ public Builder(AMRRuleSetProcessor processor) {
+
+ this.noAnomalyDetection = processor.noAnomalyDetection;
+ this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
+ this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
+ this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
+ this.unorderedRules = processor.unorderedRules;
+
+ this.voteType = processor.voteType;
+ }
+
+ public Builder noAnomalyDetection(boolean noAnomalyDetection) {
+ this.noAnomalyDetection = noAnomalyDetection;
+ return this;
+ }
+
+ public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
+ this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
+ return this;
+ }
+
+ public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
+ this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
+ return this;
+ }
+
+ public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
+ this.anomalyNumInstThreshold = anomalyNumInstThreshold;
+ return this;
+ }
+
+ public Builder unorderedRules(boolean unorderedRules) {
+ this.unorderedRules = unorderedRules;
+ return this;
+ }
+
+ public Builder voteType(int voteType) {
+ this.voteType = voteType;
+ return this;
+ }
+
+ public AMRRuleSetProcessor build() {
+ return new AMRRuleSetProcessor(this);
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java
index debe912..5cb4f01 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java
@@ -48,478 +48,487 @@
* Model Aggregator Processor (VAMR).
*
* @author Anh Thu Vu
- *
+ *
*/
public class AMRulesAggregatorProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 6303385725332704251L;
+ private static final long serialVersionUID = 6303385725332704251L;
- private static final Logger logger =
- LoggerFactory.getLogger(AMRulesAggregatorProcessor.class);
-
- private int processorId;
+ private static final Logger logger =
+ LoggerFactory.getLogger(AMRulesAggregatorProcessor.class);
- // Rules & default rule
- protected transient List<PassiveRule> ruleSet;
- protected transient ActiveRule defaultRule;
- protected transient int ruleNumberID;
- protected transient double[] statistics;
+ private int processorId;
- // SAMOA Stream
- private Stream statisticsStream;
- private Stream resultStream;
+ // Rules & default rule
+ protected transient List<PassiveRule> ruleSet;
+ protected transient ActiveRule defaultRule;
+ protected transient int ruleNumberID;
+ protected transient double[] statistics;
- // Options
- protected int pageHinckleyThreshold;
- protected double pageHinckleyAlpha;
- protected boolean driftDetection;
- protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
- protected boolean constantLearningRatioDecay;
- protected double learningRatio;
+ // SAMOA Stream
+ private Stream statisticsStream;
+ private Stream resultStream;
- protected double splitConfidence;
- protected double tieThreshold;
- protected int gracePeriod;
+ // Options
+ protected int pageHinckleyThreshold;
+ protected double pageHinckleyAlpha;
+ protected boolean driftDetection;
+ protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+ protected boolean constantLearningRatioDecay;
+ protected double learningRatio;
- protected boolean noAnomalyDetection;
- protected double multivariateAnomalyProbabilityThreshold;
- protected double univariateAnomalyprobabilityThreshold;
- protected int anomalyNumInstThreshold;
+ protected double splitConfidence;
+ protected double tieThreshold;
+ protected int gracePeriod;
- protected boolean unorderedRules;
+ protected boolean noAnomalyDetection;
+ protected double multivariateAnomalyProbabilityThreshold;
+ protected double univariateAnomalyprobabilityThreshold;
+ protected int anomalyNumInstThreshold;
- protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
- protected int voteType;
-
- /*
- * Constructor
- */
- public AMRulesAggregatorProcessor (Builder builder) {
- this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
- this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
- this.driftDetection = builder.driftDetection;
- this.predictionFunction = builder.predictionFunction;
- this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
- this.learningRatio = builder.learningRatio;
- this.splitConfidence = builder.splitConfidence;
- this.tieThreshold = builder.tieThreshold;
- this.gracePeriod = builder.gracePeriod;
-
- this.noAnomalyDetection = builder.noAnomalyDetection;
- this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
- this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
- this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
- this.unorderedRules = builder.unorderedRules;
-
- this.numericObserver = builder.numericObserver;
- this.voteType = builder.voteType;
- }
-
- /*
- * Process
- */
- @Override
- public boolean process(ContentEvent event) {
- if (event instanceof InstanceContentEvent) {
- InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
- this.processInstanceEvent(instanceEvent);
- }
- else if (event instanceof PredicateContentEvent) {
- this.updateRuleSplitNode((PredicateContentEvent) event);
- }
- else if (event instanceof RuleContentEvent) {
- RuleContentEvent rce = (RuleContentEvent) event;
- if (rce.isRemoving()) {
- this.removeRule(rce.getRuleNumberID());
- }
- }
-
- return true;
- }
-
- // Merge predict and train so we only check for covering rules one time
- private void processInstanceEvent(InstanceContentEvent instanceEvent) {
- Instance instance = instanceEvent.getInstance();
- boolean predictionCovered = false;
- boolean trainingCovered = false;
- boolean continuePrediction = instanceEvent.isTesting();
- boolean continueTraining = instanceEvent.isTraining();
-
- ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
- Iterator<PassiveRule> ruleIterator= this.ruleSet.iterator();
- while (ruleIterator.hasNext()) {
- if (!continuePrediction && !continueTraining)
- break;
-
- PassiveRule rule = ruleIterator.next();
-
- if (rule.isCovering(instance) == true){
- predictionCovered = true;
+ protected boolean unorderedRules;
- if (continuePrediction) {
- double [] vote=rule.getPrediction(instance);
- double error= rule.getCurrentError();
- errorWeightedVote.addVote(vote,error);
- if (!this.unorderedRules) continuePrediction = false;
- }
-
- if (continueTraining) {
- if (!isAnomaly(instance, rule)) {
- trainingCovered = true;
- rule.updateStatistics(instance);
- // Send instance to statistics PIs
- sendInstanceToRule(instance, rule.getRuleNumberID());
-
- if (!this.unorderedRules) continueTraining = false;
- }
- }
- }
- }
-
- if (predictionCovered) {
- // Combined prediction
- ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent);
- resultStream.put(rce);
- }
- else if (instanceEvent.isTesting()) {
- // predict with default rule
- double [] vote=defaultRule.getPrediction(instance);
- ResultContentEvent rce = newResultContentEvent(vote, instanceEvent);
- resultStream.put(rce);
- }
-
- if (!trainingCovered && instanceEvent.isTraining()) {
- // train default rule with this instance
- defaultRule.updateStatistics(instance);
- if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
- if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
- ActiveRule newDefaultRule=newRule(defaultRule.getRuleNumberID(),(RuleActiveRegressionNode)defaultRule.getLearningNode(),
- ((RuleActiveRegressionNode)defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); //other branch
- defaultRule.split();
- defaultRule.setRuleNumberID(++ruleNumberID);
- this.ruleSet.add(new PassiveRule(this.defaultRule));
- // send to statistics PI
- sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
- defaultRule=newDefaultRule;
- }
- }
- }
- }
-
- /**
- * Helper method to generate new ResultContentEvent based on an instance and
- * its prediction result.
- * @param prediction The predicted class label from the decision tree model.
- * @param inEvent The associated instance content event
- * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
- */
- private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){
- ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent());
- rce.setClassifierIndex(this.processorId);
- rce.setEvaluationIndex(inEvent.getEvaluationIndex());
- return rce;
- }
+ protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
+ protected int voteType;
- public ErrorWeightedVote newErrorWeightedVote() {
- if (voteType == 1)
- return new UniformWeightedVote();
- return new InverseErrorWeightedVote();
- }
+ /*
+ * Constructor
+ */
+ public AMRulesAggregatorProcessor(Builder builder) {
+ this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
+ this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
+ this.driftDetection = builder.driftDetection;
+ this.predictionFunction = builder.predictionFunction;
+ this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
+ this.learningRatio = builder.learningRatio;
+ this.splitConfidence = builder.splitConfidence;
+ this.tieThreshold = builder.tieThreshold;
+ this.gracePeriod = builder.gracePeriod;
- /**
- * Method to verify if the instance is an anomaly.
- * @param instance
- * @param rule
- * @return
- */
- private boolean isAnomaly(Instance instance, LearningRule rule) {
- //AMRUles is equipped with anomaly detection. If on, compute the anomaly value.
- boolean isAnomaly = false;
- if (this.noAnomalyDetection == false){
- if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
- isAnomaly = rule.isAnomaly(instance,
- this.univariateAnomalyprobabilityThreshold,
- this.multivariateAnomalyProbabilityThreshold,
- this.anomalyNumInstThreshold);
- }
- }
- return isAnomaly;
- }
-
- /*
- * Create new rules
- */
- private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
- ActiveRule r=newRule(ID);
+ this.noAnomalyDetection = builder.noAnomalyDetection;
+ this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
+ this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
+ this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
+ this.unorderedRules = builder.unorderedRules;
- if (node!=null)
- {
- if(node.getPerceptron()!=null)
- {
- r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
- r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
- }
- if (statistics==null)
- {
- double mean;
- if(node.getNodeStatistics().getValue(0)>0){
- mean=node.getNodeStatistics().getValue(1)/node.getNodeStatistics().getValue(0);
- r.getLearningNode().getTargetMean().reset(mean, 1);
- }
- }
- }
- if (statistics!=null && ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean()!=null)
- {
- double mean;
- if(statistics[0]>0){
- mean=statistics[1]/statistics[0];
- ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean().reset(mean, (long)statistics[0]);
- }
- }
- return r;
- }
+ this.numericObserver = builder.numericObserver;
+ this.voteType = builder.voteType;
+ }
- private ActiveRule newRule(int ID) {
- ActiveRule r=new ActiveRule.Builder().
- threshold(this.pageHinckleyThreshold).
- alpha(this.pageHinckleyAlpha).
- changeDetection(this.driftDetection).
- predictionFunction(this.predictionFunction).
- statistics(new double[3]).
- learningRatio(this.learningRatio).
- numericObserver(numericObserver).
- id(ID).build();
- return r;
- }
-
- /*
- * Add predicate/RuleSplitNode for a rule
- */
- private void updateRuleSplitNode(PredicateContentEvent pce) {
- int ruleID = pce.getRuleNumberID();
- for (PassiveRule rule:ruleSet) {
- if (rule.getRuleNumberID() == ruleID) {
- if (pce.getRuleSplitNode() != null)
- rule.nodeListAdd(pce.getRuleSplitNode());
- if (pce.getLearningNode() != null)
- rule.setLearningNode(pce.getLearningNode());
- }
- }
- }
-
- /*
- * Remove rule
- */
- private void removeRule(int ruleID) {
- for (PassiveRule rule:ruleSet) {
- if (rule.getRuleNumberID() == ruleID) {
- ruleSet.remove(rule);
- break;
- }
- }
- }
-
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- this.statistics= new double[]{0.0,0,0};
- this.ruleNumberID=0;
- this.defaultRule = newRule(++this.ruleNumberID);
-
- this.ruleSet = new LinkedList<PassiveRule>();
- }
-
- /*
- * Clone processor
- */
- @Override
- public Processor newProcessor(Processor p) {
- AMRulesAggregatorProcessor oldProcessor = (AMRulesAggregatorProcessor) p;
- Builder builder = new Builder(oldProcessor);
- AMRulesAggregatorProcessor newProcessor = builder.build();
- newProcessor.resultStream = oldProcessor.resultStream;
- newProcessor.statisticsStream = oldProcessor.statisticsStream;
- return newProcessor;
- }
-
- /*
- * Send events
- */
- private void sendInstanceToRule(Instance instance, int ruleID) {
- AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance);
- this.statisticsStream.put(ace);
- }
-
-
-
- private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
- RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
- this.statisticsStream.put(rce);
- }
-
- /*
- * Output streams
- */
- public void setStatisticsStream(Stream statisticsStream) {
- this.statisticsStream = statisticsStream;
- }
-
- public Stream getStatisticsStream() {
- return this.statisticsStream;
- }
-
- public void setResultStream(Stream resultStream) {
- this.resultStream = resultStream;
- }
-
- public Stream getResultStream() {
- return this.resultStream;
- }
-
- /*
- * Others
- */
- public boolean isRandomizable() {
- return true;
+ /*
+ * Process
+ */
+ @Override
+ public boolean process(ContentEvent event) {
+ if (event instanceof InstanceContentEvent) {
+ InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
+ this.processInstanceEvent(instanceEvent);
}
-
- /*
- * Builder
- */
- public static class Builder {
- private int pageHinckleyThreshold;
- private double pageHinckleyAlpha;
- private boolean driftDetection;
- private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
- private boolean constantLearningRatioDecay;
- private double learningRatio;
- private double splitConfidence;
- private double tieThreshold;
- private int gracePeriod;
-
- private boolean noAnomalyDetection;
- private double multivariateAnomalyProbabilityThreshold;
- private double univariateAnomalyprobabilityThreshold;
- private int anomalyNumInstThreshold;
-
- private boolean unorderedRules;
-
- private FIMTDDNumericAttributeClassLimitObserver numericObserver;
- private int voteType;
-
- private Instances dataset;
-
- public Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- public Builder(AMRulesAggregatorProcessor processor) {
- this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
- this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
- this.driftDetection = processor.driftDetection;
- this.predictionFunction = processor.predictionFunction;
- this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
- this.learningRatio = processor.learningRatio;
- this.splitConfidence = processor.splitConfidence;
- this.tieThreshold = processor.tieThreshold;
- this.gracePeriod = processor.gracePeriod;
-
- this.noAnomalyDetection = processor.noAnomalyDetection;
- this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
- this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
- this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
- this.unorderedRules = processor.unorderedRules;
-
- this.numericObserver = processor.numericObserver;
- this.voteType = processor.voteType;
- }
-
- public Builder threshold(int threshold) {
- this.pageHinckleyThreshold = threshold;
- return this;
- }
-
- public Builder alpha(double alpha) {
- this.pageHinckleyAlpha = alpha;
- return this;
- }
-
- public Builder changeDetection(boolean changeDetection) {
- this.driftDetection = changeDetection;
- return this;
- }
-
- public Builder predictionFunction(int predictionFunction) {
- this.predictionFunction = predictionFunction;
- return this;
- }
-
- public Builder constantLearningRatioDecay(boolean constantDecay) {
- this.constantLearningRatioDecay = constantDecay;
- return this;
- }
-
- public Builder learningRatio(double learningRatio) {
- this.learningRatio = learningRatio;
- return this;
- }
-
- public Builder splitConfidence(double splitConfidence) {
- this.splitConfidence = splitConfidence;
- return this;
- }
-
- public Builder tieThreshold(double tieThreshold) {
- this.tieThreshold = tieThreshold;
- return this;
- }
-
- public Builder gracePeriod(int gracePeriod) {
- this.gracePeriod = gracePeriod;
- return this;
- }
-
- public Builder noAnomalyDetection(boolean noAnomalyDetection) {
- this.noAnomalyDetection = noAnomalyDetection;
- return this;
- }
-
- public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
- this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
- return this;
- }
-
- public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
- this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
- return this;
- }
-
- public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
- this.anomalyNumInstThreshold = anomalyNumInstThreshold;
- return this;
- }
-
- public Builder unorderedRules(boolean unorderedRules) {
- this.unorderedRules = unorderedRules;
- return this;
- }
-
- public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
- this.numericObserver = numericObserver;
- return this;
- }
-
- public Builder voteType(int voteType) {
- this.voteType = voteType;
- return this;
- }
-
- public AMRulesAggregatorProcessor build() {
- return new AMRulesAggregatorProcessor(this);
- }
- }
+ else if (event instanceof PredicateContentEvent) {
+ this.updateRuleSplitNode((PredicateContentEvent) event);
+ }
+ else if (event instanceof RuleContentEvent) {
+ RuleContentEvent rce = (RuleContentEvent) event;
+ if (rce.isRemoving()) {
+ this.removeRule(rce.getRuleNumberID());
+ }
+ }
+
+ return true;
+ }
+
+ // Merge predict and train so we only check for covering rules one time
+ private void processInstanceEvent(InstanceContentEvent instanceEvent) {
+ Instance instance = instanceEvent.getInstance();
+ boolean predictionCovered = false;
+ boolean trainingCovered = false;
+ boolean continuePrediction = instanceEvent.isTesting();
+ boolean continueTraining = instanceEvent.isTraining();
+
+ ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
+ Iterator<PassiveRule> ruleIterator = this.ruleSet.iterator();
+ while (ruleIterator.hasNext()) {
+ if (!continuePrediction && !continueTraining)
+ break;
+
+ PassiveRule rule = ruleIterator.next();
+
+ if (rule.isCovering(instance) == true) {
+ predictionCovered = true;
+
+ if (continuePrediction) {
+ double[] vote = rule.getPrediction(instance);
+ double error = rule.getCurrentError();
+ errorWeightedVote.addVote(vote, error);
+ if (!this.unorderedRules)
+ continuePrediction = false;
+ }
+
+ if (continueTraining) {
+ if (!isAnomaly(instance, rule)) {
+ trainingCovered = true;
+ rule.updateStatistics(instance);
+ // Send instance to statistics PIs
+ sendInstanceToRule(instance, rule.getRuleNumberID());
+
+ if (!this.unorderedRules)
+ continueTraining = false;
+ }
+ }
+ }
+ }
+
+ if (predictionCovered) {
+ // Combined prediction
+ ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent);
+ resultStream.put(rce);
+ }
+ else if (instanceEvent.isTesting()) {
+ // predict with default rule
+ double[] vote = defaultRule.getPrediction(instance);
+ ResultContentEvent rce = newResultContentEvent(vote, instanceEvent);
+ resultStream.put(rce);
+ }
+
+ if (!trainingCovered && instanceEvent.isTraining()) {
+ // train default rule with this instance
+ defaultRule.updateStatistics(instance);
+ if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
+ if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
+ ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(),
+ (RuleActiveRegressionNode) defaultRule.getLearningNode(),
+ ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other
+ // branch
+ defaultRule.split();
+ defaultRule.setRuleNumberID(++ruleNumberID);
+ this.ruleSet.add(new PassiveRule(this.defaultRule));
+ // send to statistics PI
+ sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
+ defaultRule = newDefaultRule;
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper method to generate new ResultContentEvent based on an instance and
+ * its prediction result.
+ *
+ * @param prediction
+ * The predicted class label from the decision tree model.
+ * @param inEvent
+ * The associated instance content event
+ * @return ResultContentEvent to be sent into Evaluator PI or other
+ * destination PI.
+ */
+ private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+ ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
+ inEvent.getClassId(), prediction, inEvent.isLastEvent());
+ rce.setClassifierIndex(this.processorId);
+ rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+ return rce;
+ }
+
+ public ErrorWeightedVote newErrorWeightedVote() {
+ if (voteType == 1)
+ return new UniformWeightedVote();
+ return new InverseErrorWeightedVote();
+ }
+
+ /**
+ * Method to verify if the instance is an anomaly.
+ *
+ * @param instance
+ * @param rule
+ * @return
+ */
+ private boolean isAnomaly(Instance instance, LearningRule rule) {
+ // AMRUles is equipped with anomaly detection. If on, compute the anomaly
+ // value.
+ boolean isAnomaly = false;
+ if (this.noAnomalyDetection == false) {
+ if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
+ isAnomaly = rule.isAnomaly(instance,
+ this.univariateAnomalyprobabilityThreshold,
+ this.multivariateAnomalyProbabilityThreshold,
+ this.anomalyNumInstThreshold);
+ }
+ }
+ return isAnomaly;
+ }
+
+ /*
+ * Create new rules
+ */
+ private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
+ ActiveRule r = newRule(ID);
+
+ if (node != null)
+ {
+ if (node.getPerceptron() != null)
+ {
+ r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
+ r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
+ }
+ if (statistics == null)
+ {
+ double mean;
+ if (node.getNodeStatistics().getValue(0) > 0) {
+ mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
+ r.getLearningNode().getTargetMean().reset(mean, 1);
+ }
+ }
+ }
+ if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null)
+ {
+ double mean;
+ if (statistics[0] > 0) {
+ mean = statistics[1] / statistics[0];
+ ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
+ }
+ }
+ return r;
+ }
+
+ private ActiveRule newRule(int ID) {
+ ActiveRule r = new ActiveRule.Builder().
+ threshold(this.pageHinckleyThreshold).
+ alpha(this.pageHinckleyAlpha).
+ changeDetection(this.driftDetection).
+ predictionFunction(this.predictionFunction).
+ statistics(new double[3]).
+ learningRatio(this.learningRatio).
+ numericObserver(numericObserver).
+ id(ID).build();
+ return r;
+ }
+
+ /*
+ * Add predicate/RuleSplitNode for a rule
+ */
+ private void updateRuleSplitNode(PredicateContentEvent pce) {
+ int ruleID = pce.getRuleNumberID();
+ for (PassiveRule rule : ruleSet) {
+ if (rule.getRuleNumberID() == ruleID) {
+ if (pce.getRuleSplitNode() != null)
+ rule.nodeListAdd(pce.getRuleSplitNode());
+ if (pce.getLearningNode() != null)
+ rule.setLearningNode(pce.getLearningNode());
+ }
+ }
+ }
+
+ /*
+ * Remove rule
+ */
+ private void removeRule(int ruleID) {
+ for (PassiveRule rule : ruleSet) {
+ if (rule.getRuleNumberID() == ruleID) {
+ ruleSet.remove(rule);
+ break;
+ }
+ }
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ this.statistics = new double[] { 0.0, 0, 0 };
+ this.ruleNumberID = 0;
+ this.defaultRule = newRule(++this.ruleNumberID);
+
+ this.ruleSet = new LinkedList<PassiveRule>();
+ }
+
+ /*
+ * Clone processor
+ */
+ @Override
+ public Processor newProcessor(Processor p) {
+ AMRulesAggregatorProcessor oldProcessor = (AMRulesAggregatorProcessor) p;
+ Builder builder = new Builder(oldProcessor);
+ AMRulesAggregatorProcessor newProcessor = builder.build();
+ newProcessor.resultStream = oldProcessor.resultStream;
+ newProcessor.statisticsStream = oldProcessor.statisticsStream;
+ return newProcessor;
+ }
+
+ /*
+ * Send events
+ */
+ private void sendInstanceToRule(Instance instance, int ruleID) {
+ AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance);
+ this.statisticsStream.put(ace);
+ }
+
+ private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
+ RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
+ this.statisticsStream.put(rce);
+ }
+
+ /*
+ * Output streams
+ */
+ public void setStatisticsStream(Stream statisticsStream) {
+ this.statisticsStream = statisticsStream;
+ }
+
+ public Stream getStatisticsStream() {
+ return this.statisticsStream;
+ }
+
+ public void setResultStream(Stream resultStream) {
+ this.resultStream = resultStream;
+ }
+
+ public Stream getResultStream() {
+ return this.resultStream;
+ }
+
+ /*
+ * Others
+ */
+ public boolean isRandomizable() {
+ return true;
+ }
+
+ /*
+ * Builder
+ */
+ public static class Builder {
+ private int pageHinckleyThreshold;
+ private double pageHinckleyAlpha;
+ private boolean driftDetection;
+ private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+ private boolean constantLearningRatioDecay;
+ private double learningRatio;
+ private double splitConfidence;
+ private double tieThreshold;
+ private int gracePeriod;
+
+ private boolean noAnomalyDetection;
+ private double multivariateAnomalyProbabilityThreshold;
+ private double univariateAnomalyprobabilityThreshold;
+ private int anomalyNumInstThreshold;
+
+ private boolean unorderedRules;
+
+ private FIMTDDNumericAttributeClassLimitObserver numericObserver;
+ private int voteType;
+
+ private Instances dataset;
+
+ public Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ public Builder(AMRulesAggregatorProcessor processor) {
+ this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
+ this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
+ this.driftDetection = processor.driftDetection;
+ this.predictionFunction = processor.predictionFunction;
+ this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
+ this.learningRatio = processor.learningRatio;
+ this.splitConfidence = processor.splitConfidence;
+ this.tieThreshold = processor.tieThreshold;
+ this.gracePeriod = processor.gracePeriod;
+
+ this.noAnomalyDetection = processor.noAnomalyDetection;
+ this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
+ this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
+ this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
+ this.unorderedRules = processor.unorderedRules;
+
+ this.numericObserver = processor.numericObserver;
+ this.voteType = processor.voteType;
+ }
+
+ public Builder threshold(int threshold) {
+ this.pageHinckleyThreshold = threshold;
+ return this;
+ }
+
+ public Builder alpha(double alpha) {
+ this.pageHinckleyAlpha = alpha;
+ return this;
+ }
+
+ public Builder changeDetection(boolean changeDetection) {
+ this.driftDetection = changeDetection;
+ return this;
+ }
+
+ public Builder predictionFunction(int predictionFunction) {
+ this.predictionFunction = predictionFunction;
+ return this;
+ }
+
+ public Builder constantLearningRatioDecay(boolean constantDecay) {
+ this.constantLearningRatioDecay = constantDecay;
+ return this;
+ }
+
+ public Builder learningRatio(double learningRatio) {
+ this.learningRatio = learningRatio;
+ return this;
+ }
+
+ public Builder splitConfidence(double splitConfidence) {
+ this.splitConfidence = splitConfidence;
+ return this;
+ }
+
+ public Builder tieThreshold(double tieThreshold) {
+ this.tieThreshold = tieThreshold;
+ return this;
+ }
+
+ public Builder gracePeriod(int gracePeriod) {
+ this.gracePeriod = gracePeriod;
+ return this;
+ }
+
+ public Builder noAnomalyDetection(boolean noAnomalyDetection) {
+ this.noAnomalyDetection = noAnomalyDetection;
+ return this;
+ }
+
+ public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
+ this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
+ return this;
+ }
+
+ public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
+ this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
+ return this;
+ }
+
+ public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
+ this.anomalyNumInstThreshold = anomalyNumInstThreshold;
+ return this;
+ }
+
+ public Builder unorderedRules(boolean unorderedRules) {
+ this.unorderedRules = unorderedRules;
+ return this;
+ }
+
+ public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
+ this.numericObserver = numericObserver;
+ return this;
+ }
+
+ public Builder voteType(int voteType) {
+ this.voteType = voteType;
+ return this;
+ }
+
+ public AMRulesAggregatorProcessor build() {
+ return new AMRulesAggregatorProcessor(this);
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java
index da820d8..2f1cb18 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java
@@ -41,180 +41,180 @@
* Learner Processor (VAMR).
*
* @author Anh Thu Vu
- *
+ *
*/
public class AMRulesStatisticsProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 5268933189695395573L;
-
- private static final Logger logger =
- LoggerFactory.getLogger(AMRulesStatisticsProcessor.class);
-
- private int processorId;
-
- private transient List<ActiveRule> ruleSet;
-
- private Stream outputStream;
-
- private double splitConfidence;
- private double tieThreshold;
- private int gracePeriod;
-
- private int frequency;
-
- public AMRulesStatisticsProcessor(Builder builder) {
- this.splitConfidence = builder.splitConfidence;
- this.tieThreshold = builder.tieThreshold;
- this.gracePeriod = builder.gracePeriod;
- this.frequency = builder.frequency;
- }
+ private static final long serialVersionUID = 5268933189695395573L;
- @Override
- public boolean process(ContentEvent event) {
- if (event instanceof AssignmentContentEvent) {
-
- AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event;
- trainRuleOnInstance(attrContentEvent.getRuleNumberID(),attrContentEvent.getInstance());
- }
- else if (event instanceof RuleContentEvent) {
- RuleContentEvent ruleContentEvent = (RuleContentEvent) event;
- if (!ruleContentEvent.isRemoving()) {
- addRule(ruleContentEvent.getRule());
- }
- }
-
- return false;
- }
-
- /*
- * Process input instances
- */
- private void trainRuleOnInstance(int ruleID, Instance instance) {
- Iterator<ActiveRule> ruleIterator= this.ruleSet.iterator();
- while (ruleIterator.hasNext()) {
- ActiveRule rule = ruleIterator.next();
- if (rule.getRuleNumberID() == ruleID) {
- // Check (again) for coverage
- // Skip anomaly check as Aggregator's perceptron should be well-updated
- if (rule.isCovering(instance) == true) {
- double error = rule.computeError(instance); //Use adaptive mode error
- boolean changeDetected = ((RuleActiveRegressionNode)rule.getLearningNode()).updateChangeDetection(error);
- if (changeDetected == true) {
- ruleIterator.remove();
-
- this.sendRemoveRuleEvent(ruleID);
- } else {
- rule.updateStatistics(instance);
- if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
- if (rule.tryToExpand(this.splitConfidence, this.tieThreshold) ) {
- rule.split();
-
- // expanded: update Aggregator with new/updated predicate
- this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(),
- (RuleActiveRegressionNode)rule.getLearningNode());
- }
- }
- }
- }
-
- return;
- }
- }
- }
-
- private void sendRemoveRuleEvent(int ruleID) {
- RuleContentEvent rce = new RuleContentEvent(ruleID, null, true);
- this.outputStream.put(rce);
- }
-
- private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) {
- this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode)));
- }
-
- /*
- * Process control message (regarding adding or removing rules)
- */
- private boolean addRule(ActiveRule rule) {
- this.ruleSet.add(rule);
- return true;
- }
+ private static final Logger logger =
+ LoggerFactory.getLogger(AMRulesStatisticsProcessor.class);
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- this.ruleSet = new LinkedList<ActiveRule>();
- }
+ private int processorId;
- @Override
- public Processor newProcessor(Processor p) {
- AMRulesStatisticsProcessor oldProcessor = (AMRulesStatisticsProcessor)p;
- AMRulesStatisticsProcessor newProcessor =
- new AMRulesStatisticsProcessor.Builder(oldProcessor).build();
-
- newProcessor.setOutputStream(oldProcessor.outputStream);
- return newProcessor;
- }
-
- /*
- * Builder
- */
- public static class Builder {
- private double splitConfidence;
- private double tieThreshold;
- private int gracePeriod;
-
- private int frequency;
-
- private Instances dataset;
-
- public Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- public Builder(AMRulesStatisticsProcessor processor) {
- this.splitConfidence = processor.splitConfidence;
- this.tieThreshold = processor.tieThreshold;
- this.gracePeriod = processor.gracePeriod;
- this.frequency = processor.frequency;
- }
-
- public Builder splitConfidence(double splitConfidence) {
- this.splitConfidence = splitConfidence;
- return this;
- }
-
- public Builder tieThreshold(double tieThreshold) {
- this.tieThreshold = tieThreshold;
- return this;
- }
-
- public Builder gracePeriod(int gracePeriod) {
- this.gracePeriod = gracePeriod;
- return this;
- }
-
- public Builder frequency(int frequency) {
- this.frequency = frequency;
- return this;
- }
-
- public AMRulesStatisticsProcessor build() {
- return new AMRulesStatisticsProcessor(this);
- }
- }
-
- /*
- * Output stream
- */
- public void setOutputStream(Stream stream) {
- this.outputStream = stream;
- }
-
- public Stream getOutputStream() {
- return this.outputStream;
- }
+ private transient List<ActiveRule> ruleSet;
+
+ private Stream outputStream;
+
+ private double splitConfidence;
+ private double tieThreshold;
+ private int gracePeriod;
+
+ private int frequency;
+
+ public AMRulesStatisticsProcessor(Builder builder) {
+ this.splitConfidence = builder.splitConfidence;
+ this.tieThreshold = builder.tieThreshold;
+ this.gracePeriod = builder.gracePeriod;
+ this.frequency = builder.frequency;
+ }
+
+ @Override
+ public boolean process(ContentEvent event) {
+ if (event instanceof AssignmentContentEvent) {
+
+ AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event;
+ trainRuleOnInstance(attrContentEvent.getRuleNumberID(), attrContentEvent.getInstance());
+ }
+ else if (event instanceof RuleContentEvent) {
+ RuleContentEvent ruleContentEvent = (RuleContentEvent) event;
+ if (!ruleContentEvent.isRemoving()) {
+ addRule(ruleContentEvent.getRule());
+ }
+ }
+
+ return false;
+ }
+
+ /*
+ * Process input instances
+ */
+ private void trainRuleOnInstance(int ruleID, Instance instance) {
+ Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
+ while (ruleIterator.hasNext()) {
+ ActiveRule rule = ruleIterator.next();
+ if (rule.getRuleNumberID() == ruleID) {
+ // Check (again) for coverage
+ // Skip anomaly check as Aggregator's perceptron should be well-updated
+ if (rule.isCovering(instance) == true) {
+ double error = rule.computeError(instance); // Use adaptive mode error
+ boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
+ if (changeDetected == true) {
+ ruleIterator.remove();
+
+ this.sendRemoveRuleEvent(ruleID);
+ } else {
+ rule.updateStatistics(instance);
+ if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
+ if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
+ rule.split();
+
+ // expanded: update Aggregator with new/updated predicate
+ this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(),
+ (RuleActiveRegressionNode) rule.getLearningNode());
+ }
+ }
+ }
+ }
+
+ return;
+ }
+ }
+ }
+
+ private void sendRemoveRuleEvent(int ruleID) {
+ RuleContentEvent rce = new RuleContentEvent(ruleID, null, true);
+ this.outputStream.put(rce);
+ }
+
+ private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) {
+ this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode)));
+ }
+
+ /*
+ * Process control message (regarding adding or removing rules)
+ */
+ private boolean addRule(ActiveRule rule) {
+ this.ruleSet.add(rule);
+ return true;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ this.ruleSet = new LinkedList<ActiveRule>();
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ AMRulesStatisticsProcessor oldProcessor = (AMRulesStatisticsProcessor) p;
+ AMRulesStatisticsProcessor newProcessor =
+ new AMRulesStatisticsProcessor.Builder(oldProcessor).build();
+
+ newProcessor.setOutputStream(oldProcessor.outputStream);
+ return newProcessor;
+ }
+
+ /*
+ * Builder
+ */
+ public static class Builder {
+ private double splitConfidence;
+ private double tieThreshold;
+ private int gracePeriod;
+
+ private int frequency;
+
+ private Instances dataset;
+
+ public Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ public Builder(AMRulesStatisticsProcessor processor) {
+ this.splitConfidence = processor.splitConfidence;
+ this.tieThreshold = processor.tieThreshold;
+ this.gracePeriod = processor.gracePeriod;
+ this.frequency = processor.frequency;
+ }
+
+ public Builder splitConfidence(double splitConfidence) {
+ this.splitConfidence = splitConfidence;
+ return this;
+ }
+
+ public Builder tieThreshold(double tieThreshold) {
+ this.tieThreshold = tieThreshold;
+ return this;
+ }
+
+ public Builder gracePeriod(int gracePeriod) {
+ this.gracePeriod = gracePeriod;
+ return this;
+ }
+
+ public Builder frequency(int frequency) {
+ this.frequency = frequency;
+ return this;
+ }
+
+ public AMRulesStatisticsProcessor build() {
+ return new AMRulesStatisticsProcessor(this);
+ }
+ }
+
+ /*
+ * Output stream
+ */
+ public void setOutputStream(Stream stream) {
+ this.outputStream = stream;
+ }
+
+ public Stream getOutputStream() {
+ return this.outputStream;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java
index 5a03406..43814b4 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java
@@ -27,48 +27,48 @@
* Forwarded instances from Model Agrregator to Learners/Default Rule Learner.
*
* @author Anh Thu Vu
- *
+ *
*/
public class AssignmentContentEvent implements ContentEvent {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 1031695762172836629L;
+ private static final long serialVersionUID = 1031695762172836629L;
- private int ruleNumberID;
- private Instance instance;
-
- public AssignmentContentEvent() {
- this(0, null);
- }
-
- public AssignmentContentEvent(int ruleID, Instance instance) {
- this.ruleNumberID = ruleID;
- this.instance = instance;
- }
-
- @Override
- public String getKey() {
- return Integer.toString(this.ruleNumberID);
- }
+ private int ruleNumberID;
+ private Instance instance;
- @Override
- public void setKey(String key) {
- // do nothing
- }
+ public AssignmentContentEvent() {
+ this(0, null);
+ }
- @Override
- public boolean isLastEvent() {
- return false;
- }
-
- public Instance getInstance() {
- return this.instance;
- }
-
- public int getRuleNumberID() {
- return this.ruleNumberID;
- }
+ public AssignmentContentEvent(int ruleID, Instance instance) {
+ this.ruleNumberID = ruleID;
+ this.instance = instance;
+ }
+
+ @Override
+ public String getKey() {
+ return Integer.toString(this.ruleNumberID);
+ }
+
+ @Override
+ public void setKey(String key) {
+ // do nothing
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return false;
+ }
+
+ public Instance getInstance() {
+ return this.instance;
+ }
+
+ public int getRuleNumberID() {
+ return this.ruleNumberID;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java
index 69e935a..f6c8934 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java
@@ -28,57 +28,58 @@
* New features (of newly expanded rules) from Learners to Model Aggregators.
*
* @author Anh Thu Vu
- *
+ *
*/
public class PredicateContentEvent implements ContentEvent {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 7909435830443732451L;
-
- private int ruleNumberID;
- private RuleSplitNode ruleSplitNode;
- private RulePassiveRegressionNode learningNode;
-
- /*
- * Constructor
- */
- public PredicateContentEvent() {
- this(0, null, null);
- }
-
- public PredicateContentEvent (int ruleID, RuleSplitNode ruleSplitNode, RulePassiveRegressionNode learningNode) {
- this.ruleNumberID = ruleID;
- this.ruleSplitNode = ruleSplitNode; // is this is null: this is for updating learningNode only
- this.learningNode = learningNode;
- }
-
- @Override
- public String getKey() {
- return Integer.toString(this.ruleNumberID);
- }
+ private static final long serialVersionUID = 7909435830443732451L;
- @Override
- public void setKey(String key) {
- // do nothing
- }
+ private int ruleNumberID;
+ private RuleSplitNode ruleSplitNode;
+ private RulePassiveRegressionNode learningNode;
- @Override
- public boolean isLastEvent() {
- return false; // N/A
- }
-
- public int getRuleNumberID() {
- return this.ruleNumberID;
- }
-
- public RuleSplitNode getRuleSplitNode() {
- return this.ruleSplitNode;
- }
-
- public RulePassiveRegressionNode getLearningNode() {
- return this.learningNode;
- }
+ /*
+ * Constructor
+ */
+ public PredicateContentEvent() {
+ this(0, null, null);
+ }
+
+ public PredicateContentEvent(int ruleID, RuleSplitNode ruleSplitNode, RulePassiveRegressionNode learningNode) {
+ this.ruleNumberID = ruleID;
+ this.ruleSplitNode = ruleSplitNode; // if this is null: this is for updating
+ // learningNode only
+ this.learningNode = learningNode;
+ }
+
+ @Override
+ public String getKey() {
+ return Integer.toString(this.ruleNumberID);
+ }
+
+ @Override
+ public void setKey(String key) {
+ // do nothing
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return false; // N/A
+ }
+
+ public int getRuleNumberID() {
+ return this.ruleNumberID;
+ }
+
+ public RuleSplitNode getRuleSplitNode() {
+ return this.ruleSplitNode;
+ }
+
+ public RulePassiveRegressionNode getLearningNode() {
+ return this.learningNode;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java
index a9dab4a..ac7aced 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java
@@ -24,59 +24,59 @@
import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule;
/**
- * New rule from Model Aggregator/Default Rule Learner to Learners
- * or removed rule from Learner to Model Aggregators.
+ * New rule from Model Aggregator/Default Rule Learner to Learners or removed
+ * rule from Learner to Model Aggregators.
*
* @author Anh Thu Vu
- *
+ *
*/
public class RuleContentEvent implements ContentEvent {
-
- /**
+ /**
*
*/
- private static final long serialVersionUID = -9046390274402894461L;
-
- private final int ruleNumberID;
- private final ActiveRule addingRule; // for removing rule, we only need the rule's ID
- private final boolean isRemoving;
-
- public RuleContentEvent() {
- this(0, null, false);
- }
-
- public RuleContentEvent(int ruleID, ActiveRule rule, boolean isRemoving) {
- this.ruleNumberID = ruleID;
- this.isRemoving = isRemoving;
- this.addingRule = rule;
- }
+ private static final long serialVersionUID = -9046390274402894461L;
- @Override
- public String getKey() {
- return Integer.toString(this.ruleNumberID);
- }
+ private final int ruleNumberID;
+ private final ActiveRule addingRule; // for removing rule, we only need the
+ // rule's ID
+ private final boolean isRemoving;
- @Override
- public void setKey(String key) {
- // do nothing
- }
+ public RuleContentEvent() {
+ this(0, null, false);
+ }
- @Override
- public boolean isLastEvent() {
- return false;
- }
-
- public int getRuleNumberID() {
- return this.ruleNumberID;
- }
-
- public ActiveRule getRule() {
- return this.addingRule;
- }
-
- public boolean isRemoving() {
- return this.isRemoving;
- }
+ public RuleContentEvent(int ruleID, ActiveRule rule, boolean isRemoving) {
+ this.ruleNumberID = ruleID;
+ this.isRemoving = isRemoving;
+ this.addingRule = rule;
+ }
+
+ @Override
+ public String getKey() {
+ return Integer.toString(this.ruleNumberID);
+ }
+
+ @Override
+ public void setKey(String key) {
+ // do nothing
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return false;
+ }
+
+ public int getRuleNumberID() {
+ return this.ruleNumberID;
+ }
+
+ public ActiveRule getRule() {
+ return this.addingRule;
+ }
+
+ public boolean isRemoving() {
+ return this.isRemoving;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java
index 40d260c..39abbbe 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java
@@ -31,177 +31,178 @@
import com.yahoo.labs.samoa.instances.Instance;
final class ActiveLearningNode extends LearningNode {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -2892102872646338908L;
- private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class);
-
- private double weightSeenAtLastSplitEvaluation;
-
- private final Map<Integer, String> attributeContentEventKeys;
-
- private AttributeSplitSuggestion bestSuggestion;
- private AttributeSplitSuggestion secondBestSuggestion;
-
- private final long id;
- private final int parallelismHint;
- private int suggestionCtr;
- private int thrownAwayInstance;
-
- private boolean isSplitting;
-
- ActiveLearningNode(double[] classObservation, int parallelismHint) {
- super(classObservation);
- this.weightSeenAtLastSplitEvaluation = this.getWeightSeen();
- this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate();
- this.attributeContentEventKeys = new HashMap<>();
- this.isSplitting = false;
- this.parallelismHint = parallelismHint;
- }
-
- long getId(){
- return id;
- }
+ private static final long serialVersionUID = -2892102872646338908L;
+ private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class);
- protected AttributeBatchContentEvent[] attributeBatchContentEvent;
+ private double weightSeenAtLastSplitEvaluation;
- public AttributeBatchContentEvent[] getAttributeBatchContentEvent() {
- return this.attributeBatchContentEvent;
+ private final Map<Integer, String> attributeContentEventKeys;
+
+ private AttributeSplitSuggestion bestSuggestion;
+ private AttributeSplitSuggestion secondBestSuggestion;
+
+ private final long id;
+ private final int parallelismHint;
+ private int suggestionCtr;
+ private int thrownAwayInstance;
+
+ private boolean isSplitting;
+
+ ActiveLearningNode(double[] classObservation, int parallelismHint) {
+ super(classObservation);
+ this.weightSeenAtLastSplitEvaluation = this.getWeightSeen();
+ this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate();
+ this.attributeContentEventKeys = new HashMap<>();
+ this.isSplitting = false;
+ this.parallelismHint = parallelismHint;
+ }
+
+ long getId() {
+ return id;
+ }
+
+ protected AttributeBatchContentEvent[] attributeBatchContentEvent;
+
+ public AttributeBatchContentEvent[] getAttributeBatchContentEvent() {
+ return this.attributeBatchContentEvent;
+ }
+
+ public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) {
+ this.attributeBatchContentEvent = attributeBatchContentEvent;
+ }
+
+ @Override
+ void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) {
+ // TODO: what statistics should we keep for unused instance?
+ if (isSplitting) { // currently throw all instance will splitting
+ this.thrownAwayInstance++;
+ return;
+ }
+ this.observedClassDistribution.addToValue((int) inst.classValue(),
+ inst.weight());
+ // done: parallelize by sending attributes one by one
+ // TODO: meanwhile, we can try to use the ThreadPool to execute it
+ // separately
+ // TODO: parallelize by sending in batch, i.e. split the attributes into
+ // chunk instead of send the attribute one by one
+ for (int i = 0; i < inst.numAttributes() - 1; i++) {
+ int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
+ Integer obsIndex = i;
+ String key = attributeContentEventKeys.get(obsIndex);
+
+ if (key == null) {
+ key = this.generateKey(i);
+ attributeContentEventKeys.put(obsIndex, key);
+ }
+ AttributeContentEvent ace = new AttributeContentEvent.Builder(
+ this.id, i, key)
+ .attrValue(inst.value(instAttIndex))
+ .classValue((int) inst.classValue())
+ .weight(inst.weight())
+ .isNominal(inst.attribute(instAttIndex).isNominal())
+ .build();
+ if (this.attributeBatchContentEvent == null) {
+ this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1];
+ }
+ if (this.attributeBatchContentEvent[i] == null) {
+ this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder(
+ this.id, i, key)
+ // .attrValue(inst.value(instAttIndex))
+ // .classValue((int) inst.classValue())
+ // .weight(inst.weight()]
+ .isNominal(inst.attribute(instAttIndex).isNominal())
+ .build();
+ }
+ this.attributeBatchContentEvent[i].add(ace);
+ // proc.sendToAttributeStream(ace);
+ }
+ }
+
+ @Override
+ double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) {
+ return this.observedClassDistribution.getArrayCopy();
+ }
+
+ double getWeightSeen() {
+ return this.observedClassDistribution.sumOfValues();
+ }
+
+ void setWeightSeenAtLastSplitEvaluation(double weight) {
+ this.weightSeenAtLastSplitEvaluation = weight;
+ }
+
+ double getWeightSeenAtLastSplitEvaluation() {
+ return this.weightSeenAtLastSplitEvaluation;
+ }
+
+ void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) {
+ this.isSplitting = true;
+ this.suggestionCtr = 0;
+ this.thrownAwayInstance = 0;
+
+ ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id,
+ this.getObservedClassDistribution());
+ modelAggrProc.sendToControlStream(cce);
+ }
+
+ void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion) {
+ // starts comparing from the best suggestion
+ if (bestSuggestion != null) {
+ if ((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)) {
+ this.secondBestSuggestion = this.bestSuggestion;
+ this.bestSuggestion = bestSuggestion;
+
+ if (secondBestSuggestion != null) {
+
+ if ((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)) {
+ this.secondBestSuggestion = secondBestSuggestion;
+ }
+ }
+ } else {
+ if ((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)) {
+ this.secondBestSuggestion = bestSuggestion;
+ }
+ }
}
- public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) {
- this.attributeBatchContentEvent = attributeBatchContentEvent;
- }
-
- @Override
- void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) {
- //TODO: what statistics should we keep for unused instance?
- if(isSplitting){ //currently throw all instance will splitting
- this.thrownAwayInstance++;
- return;
- }
- this.observedClassDistribution.addToValue((int)inst.classValue(),
- inst.weight());
- //done: parallelize by sending attributes one by one
- //TODO: meanwhile, we can try to use the ThreadPool to execute it separately
- //TODO: parallelize by sending in batch, i.e. split the attributes into
- //chunk instead of send the attribute one by one
- for(int i = 0; i < inst.numAttributes() - 1; i++){
- int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
- Integer obsIndex = i;
- String key = attributeContentEventKeys.get(obsIndex);
-
- if(key == null){
- key = this.generateKey(i);
- attributeContentEventKeys.put(obsIndex, key);
- }
- AttributeContentEvent ace = new AttributeContentEvent.Builder(
- this.id, i, key)
- .attrValue(inst.value(instAttIndex))
- .classValue((int) inst.classValue())
- .weight(inst.weight())
- .isNominal(inst.attribute(instAttIndex).isNominal())
- .build();
- if (this.attributeBatchContentEvent == null){
- this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1];
- }
- if (this.attributeBatchContentEvent[i] == null){
- this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder(
- this.id, i, key)
- //.attrValue(inst.value(instAttIndex))
- //.classValue((int) inst.classValue())
- //.weight(inst.weight()]
- .isNominal(inst.attribute(instAttIndex).isNominal())
- .build();
- }
- this.attributeBatchContentEvent[i].add(ace);
- //proc.sendToAttributeStream(ace);
- }
- }
-
- @Override
- double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) {
- return this.observedClassDistribution.getArrayCopy();
- }
-
- double getWeightSeen(){
- return this.observedClassDistribution.sumOfValues();
- }
-
- void setWeightSeenAtLastSplitEvaluation(double weight){
- this.weightSeenAtLastSplitEvaluation = weight;
- }
-
- double getWeightSeenAtLastSplitEvaluation(){
- return this.weightSeenAtLastSplitEvaluation;
- }
-
- void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) {
- this.isSplitting = true;
- this.suggestionCtr = 0;
- this.thrownAwayInstance = 0;
-
- ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id,
- this.getObservedClassDistribution());
- modelAggrProc.sendToControlStream(cce);
- }
-
- void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion){
- //starts comparing from the best suggestion
- if(bestSuggestion != null){
- if((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)){
- this.secondBestSuggestion = this.bestSuggestion;
- this.bestSuggestion = bestSuggestion;
-
- if(secondBestSuggestion != null){
-
- if((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)){
- this.secondBestSuggestion = secondBestSuggestion;
- }
- }
- }else{
- if((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)){
- this.secondBestSuggestion = bestSuggestion;
- }
- }
- }
-
- //TODO: optimize the code to use less memory
- this.suggestionCtr++;
- }
+ // TODO: optimize the code to use less memory
+ this.suggestionCtr++;
+ }
- boolean isSplitting(){
- return this.isSplitting;
- }
-
- void endSplitting(){
- this.isSplitting = false;
- logger.trace("wasted instance: {}", this.thrownAwayInstance);
- this.thrownAwayInstance = 0;
- }
-
- AttributeSplitSuggestion getDistributedBestSuggestion(){
- return this.bestSuggestion;
- }
-
- AttributeSplitSuggestion getDistributedSecondBestSuggestion(){
- return this.secondBestSuggestion;
- }
-
- boolean isAllSuggestionsCollected(){
- return (this.suggestionCtr == this.parallelismHint);
- }
-
- private static int modelAttIndexToInstanceAttIndex(int index, Instance inst){
- return inst.classIndex() > index ? index : index + 1;
- }
-
- private String generateKey(int obsIndex){
- final int prime = 31;
- int result = 1;
- result = prime * result + (int) (this.id ^ (this.id >>> 32));
- result = prime * result + obsIndex;
- return Integer.toString(result);
- }
+ boolean isSplitting() {
+ return this.isSplitting;
+ }
+
+ void endSplitting() {
+ this.isSplitting = false;
+ logger.trace("wasted instance: {}", this.thrownAwayInstance);
+ this.thrownAwayInstance = 0;
+ }
+
+ AttributeSplitSuggestion getDistributedBestSuggestion() {
+ return this.bestSuggestion;
+ }
+
+ AttributeSplitSuggestion getDistributedSecondBestSuggestion() {
+ return this.secondBestSuggestion;
+ }
+
+ boolean isAllSuggestionsCollected() {
+ return (this.suggestionCtr == this.parallelismHint);
+ }
+
+ private static int modelAttIndexToInstanceAttIndex(int index, Instance inst) {
+ return inst.classIndex() > index ? index : index + 1;
+ }
+
+ private String generateKey(int obsIndex) {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + (int) (this.id ^ (this.id >>> 32));
+ result = prime * result + obsIndex;
+ return Integer.toString(result);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java
index 691d0fb..c67efcb 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java
@@ -25,110 +25,112 @@
import java.util.List;
/**
- * Attribute Content Event represents the instances that split vertically
- * based on their attribute
+ * Attribute Content Event represents the instances that split vertically based
+ * on their attribute
+ *
* @author Arinto Murdopo
- *
+ *
*/
final class AttributeBatchContentEvent implements ContentEvent {
- private static final long serialVersionUID = 6652815649846676832L;
+ private static final long serialVersionUID = 6652815649846676832L;
- private final long learningNodeId;
- private final int obsIndex;
- private final List<ContentEvent> contentEventList;
- private final transient String key;
- private final boolean isNominal;
-
- public AttributeBatchContentEvent(){
- learningNodeId = -1;
- obsIndex = -1;
- contentEventList = new LinkedList<>();
- key = "";
- isNominal = true;
- }
-
- private AttributeBatchContentEvent(Builder builder){
- this.learningNodeId = builder.learningNodeId;
- this.obsIndex = builder.obsIndex;
- this.contentEventList = new LinkedList<>();
- if (builder.contentEvent != null) {
- this.contentEventList.add(builder.contentEvent);
- }
- this.isNominal = builder.isNominal;
- this.key = builder.key;
- }
-
- public void add(ContentEvent contentEvent){
- this.contentEventList.add(contentEvent);
- }
-
- @Override
- public String getKey() {
- return this.key;
- }
-
- @Override
- public void setKey(String str) {
- //do nothing, maybe useful when we want to reuse the object for serialization/deserialization purpose
- }
+ private final long learningNodeId;
+ private final int obsIndex;
+ private final List<ContentEvent> contentEventList;
+ private final transient String key;
+ private final boolean isNominal;
- @Override
- public boolean isLastEvent() {
- return false;
- }
-
- long getLearningNodeId(){
- return this.learningNodeId;
- }
-
- int getObsIndex(){
- return this.obsIndex;
- }
+ public AttributeBatchContentEvent() {
+ learningNodeId = -1;
+ obsIndex = -1;
+ contentEventList = new LinkedList<>();
+ key = "";
+ isNominal = true;
+ }
- public List<ContentEvent> getContentEventList(){
- return this.contentEventList;
- }
-
- boolean isNominal(){
- return this.isNominal;
- }
-
- static final class Builder{
-
- //required parameters
- private final long learningNodeId;
- private final int obsIndex;
- private final String key;
-
- private ContentEvent contentEvent;
- private boolean isNominal = false;
-
- Builder(long id, int obsIndex, String key){
- this.learningNodeId = id;
- this.obsIndex = obsIndex;
- this.key = key;
- }
-
- private Builder(long id, int obsIndex){
- this.learningNodeId = id;
- this.obsIndex = obsIndex;
- this.key = "";
- }
-
- Builder contentEvent(ContentEvent contentEvent){
- this.contentEvent = contentEvent;
- return this;
- }
-
- Builder isNominal(boolean val){
- this.isNominal = val;
- return this;
- }
-
- AttributeBatchContentEvent build(){
- return new AttributeBatchContentEvent(this);
- }
- }
-
+ private AttributeBatchContentEvent(Builder builder) {
+ this.learningNodeId = builder.learningNodeId;
+ this.obsIndex = builder.obsIndex;
+ this.contentEventList = new LinkedList<>();
+ if (builder.contentEvent != null) {
+ this.contentEventList.add(builder.contentEvent);
+ }
+ this.isNominal = builder.isNominal;
+ this.key = builder.key;
+ }
+
+ public void add(ContentEvent contentEvent) {
+ this.contentEventList.add(contentEvent);
+ }
+
+ @Override
+ public String getKey() {
+ return this.key;
+ }
+
+ @Override
+ public void setKey(String str) {
+ // do nothing, maybe useful when we want to reuse the object for
+ // serialization/deserialization purpose
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return false;
+ }
+
+ long getLearningNodeId() {
+ return this.learningNodeId;
+ }
+
+ int getObsIndex() {
+ return this.obsIndex;
+ }
+
+ public List<ContentEvent> getContentEventList() {
+ return this.contentEventList;
+ }
+
+ boolean isNominal() {
+ return this.isNominal;
+ }
+
+ static final class Builder {
+
+ // required parameters
+ private final long learningNodeId;
+ private final int obsIndex;
+ private final String key;
+
+ private ContentEvent contentEvent;
+ private boolean isNominal = false;
+
+ Builder(long id, int obsIndex, String key) {
+ this.learningNodeId = id;
+ this.obsIndex = obsIndex;
+ this.key = key;
+ }
+
+ private Builder(long id, int obsIndex) {
+ this.learningNodeId = id;
+ this.obsIndex = obsIndex;
+ this.key = "";
+ }
+
+ Builder contentEvent(ContentEvent contentEvent) {
+ this.contentEvent = contentEvent;
+ return this;
+ }
+
+ Builder isNominal(boolean val) {
+ this.isNominal = val;
+ return this;
+ }
+
+ AttributeBatchContentEvent build() {
+ return new AttributeBatchContentEvent(this);
+ }
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java
index 4cbdd95..ca45d30 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java
@@ -28,195 +28,198 @@
import com.yahoo.labs.samoa.core.ContentEvent;
/**
- * Attribute Content Event represents the instances that split vertically
- * based on their attribute
+ * Attribute Content Event represents the instances that split vertically based
+ * on their attribute
+ *
* @author Arinto Murdopo
- *
+ *
*/
public final class AttributeContentEvent implements ContentEvent {
- private static final long serialVersionUID = 6652815649846676832L;
+ private static final long serialVersionUID = 6652815649846676832L;
- private final long learningNodeId;
- private final int obsIndex;
- private final double attrVal;
- private final int classVal;
- private final double weight;
- private final transient String key;
- private final boolean isNominal;
-
- public AttributeContentEvent(){
- learningNodeId = -1;
- obsIndex = -1;
- attrVal = 0.0;
- classVal = -1;
- weight = 0.0;
- key = "";
- isNominal = true;
- }
-
- private AttributeContentEvent(Builder builder){
- this.learningNodeId = builder.learningNodeId;
- this.obsIndex = builder.obsIndex;
- this.attrVal = builder.attrVal;
- this.classVal = builder.classVal;
- this.weight = builder.weight;
- this.isNominal = builder.isNominal;
- this.key = builder.key;
- }
-
- @Override
- public String getKey() {
- return this.key;
- }
-
- @Override
- public void setKey(String str) {
- //do nothing, maybe useful when we want to reuse the object for serialization/deserialization purpose
- }
+ private final long learningNodeId;
+ private final int obsIndex;
+ private final double attrVal;
+ private final int classVal;
+ private final double weight;
+ private final transient String key;
+ private final boolean isNominal;
- @Override
- public boolean isLastEvent() {
- return false;
- }
-
- long getLearningNodeId(){
- return this.learningNodeId;
- }
-
- int getObsIndex(){
- return this.obsIndex;
- }
-
- int getClassVal(){
- return this.classVal;
- }
-
- double getAttrVal(){
- return this.attrVal;
- }
-
- double getWeight(){
- return this.weight;
- }
-
- boolean isNominal(){
- return this.isNominal;
- }
-
- static final class Builder{
-
- //required parameters
- private final long learningNodeId;
- private final int obsIndex;
- private final String key;
-
- //optional parameters
- private double attrVal = 0.0;
- private int classVal = 0;
- private double weight = 0.0;
- private boolean isNominal = false;
-
- Builder(long id, int obsIndex, String key){
- this.learningNodeId = id;
- this.obsIndex = obsIndex;
- this.key = key;
- }
-
- private Builder(long id, int obsIndex){
- this.learningNodeId = id;
- this.obsIndex = obsIndex;
- this.key = "";
- }
-
- Builder attrValue(double val){
- this.attrVal = val;
- return this;
- }
-
- Builder classValue(int val){
- this.classVal = val;
- return this;
- }
-
- Builder weight(double val){
- this.weight = val;
- return this;
- }
-
- Builder isNominal(boolean val){
- this.isNominal = val;
- return this;
- }
-
- AttributeContentEvent build(){
- return new AttributeContentEvent(this);
- }
- }
-
- /**
- * The Kryo serializer class for AttributeContentEvent when executing on top of Storm.
- * This class allow us to change the precision of the statistics.
- * @author Arinto Murdopo
- *
- */
- public static final class AttributeCESerializer extends Serializer<AttributeContentEvent>{
+ public AttributeContentEvent() {
+ learningNodeId = -1;
+ obsIndex = -1;
+ attrVal = 0.0;
+ classVal = -1;
+ weight = 0.0;
+ key = "";
+ isNominal = true;
+ }
- private static double PRECISION = 1000000.0;
- @Override
- public void write(Kryo kryo, Output output, AttributeContentEvent event) {
- output.writeLong(event.learningNodeId, true);
- output.writeInt(event.obsIndex, true);
- output.writeDouble(event.attrVal, PRECISION, true);
- output.writeInt(event.classVal, true);
- output.writeDouble(event.weight, PRECISION, true);
- output.writeBoolean(event.isNominal);
- }
+ private AttributeContentEvent(Builder builder) {
+ this.learningNodeId = builder.learningNodeId;
+ this.obsIndex = builder.obsIndex;
+ this.attrVal = builder.attrVal;
+ this.classVal = builder.classVal;
+ this.weight = builder.weight;
+ this.isNominal = builder.isNominal;
+ this.key = builder.key;
+ }
- @Override
- public AttributeContentEvent read(Kryo kryo, Input input,
- Class<AttributeContentEvent> type) {
- AttributeContentEvent ace
- = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true))
- .attrValue(input.readDouble(PRECISION, true))
- .classValue(input.readInt(true))
- .weight(input.readDouble(PRECISION, true))
- .isNominal(input.readBoolean())
- .build();
- return ace;
- }
- }
-
- /**
- * The Kryo serializer class for AttributeContentEvent when executing on top of Storm
- * with full precision of the statistics.
- * @author Arinto Murdopo
- *
- */
- public static final class AttributeCEFullPrecSerializer extends Serializer<AttributeContentEvent>{
+ @Override
+ public String getKey() {
+ return this.key;
+ }
- @Override
- public void write(Kryo kryo, Output output, AttributeContentEvent event) {
- output.writeLong(event.learningNodeId, true);
- output.writeInt(event.obsIndex, true);
- output.writeDouble(event.attrVal);
- output.writeInt(event.classVal, true);
- output.writeDouble(event.weight);
- output.writeBoolean(event.isNominal);
- }
+ @Override
+ public void setKey(String str) {
+ // do nothing, maybe useful when we want to reuse the object for
+ // serialization/deserialization purpose
+ }
- @Override
- public AttributeContentEvent read(Kryo kryo, Input input,
- Class<AttributeContentEvent> type) {
- AttributeContentEvent ace
- = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true))
- .attrValue(input.readDouble())
- .classValue(input.readInt(true))
- .weight(input.readDouble())
- .isNominal(input.readBoolean())
- .build();
- return ace;
- }
-
- }
+ @Override
+ public boolean isLastEvent() {
+ return false;
+ }
+
+ long getLearningNodeId() {
+ return this.learningNodeId;
+ }
+
+ int getObsIndex() {
+ return this.obsIndex;
+ }
+
+ int getClassVal() {
+ return this.classVal;
+ }
+
+ double getAttrVal() {
+ return this.attrVal;
+ }
+
+ double getWeight() {
+ return this.weight;
+ }
+
+ boolean isNominal() {
+ return this.isNominal;
+ }
+
+ static final class Builder {
+
+ // required parameters
+ private final long learningNodeId;
+ private final int obsIndex;
+ private final String key;
+
+ // optional parameters
+ private double attrVal = 0.0;
+ private int classVal = 0;
+ private double weight = 0.0;
+ private boolean isNominal = false;
+
+ Builder(long id, int obsIndex, String key) {
+ this.learningNodeId = id;
+ this.obsIndex = obsIndex;
+ this.key = key;
+ }
+
+ private Builder(long id, int obsIndex) {
+ this.learningNodeId = id;
+ this.obsIndex = obsIndex;
+ this.key = "";
+ }
+
+ Builder attrValue(double val) {
+ this.attrVal = val;
+ return this;
+ }
+
+ Builder classValue(int val) {
+ this.classVal = val;
+ return this;
+ }
+
+ Builder weight(double val) {
+ this.weight = val;
+ return this;
+ }
+
+ Builder isNominal(boolean val) {
+ this.isNominal = val;
+ return this;
+ }
+
+ AttributeContentEvent build() {
+ return new AttributeContentEvent(this);
+ }
+ }
+
+ /**
+ * The Kryo serializer class for AttributeContentEvent when executing on top
+ * of Storm. This class allows us to change the precision of the statistics.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ public static final class AttributeCESerializer extends Serializer<AttributeContentEvent> {
+
+ private static double PRECISION = 1000000.0;
+
+ @Override
+ public void write(Kryo kryo, Output output, AttributeContentEvent event) {
+ output.writeLong(event.learningNodeId, true);
+ output.writeInt(event.obsIndex, true);
+ output.writeDouble(event.attrVal, PRECISION, true);
+ output.writeInt(event.classVal, true);
+ output.writeDouble(event.weight, PRECISION, true);
+ output.writeBoolean(event.isNominal);
+ }
+
+ @Override
+ public AttributeContentEvent read(Kryo kryo, Input input,
+ Class<AttributeContentEvent> type) {
+ AttributeContentEvent ace = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true))
+ .attrValue(input.readDouble(PRECISION, true))
+ .classValue(input.readInt(true))
+ .weight(input.readDouble(PRECISION, true))
+ .isNominal(input.readBoolean())
+ .build();
+ return ace;
+ }
+ }
+
+ /**
+ * The Kryo serializer class for AttributeContentEvent when executing on top
+ * of Storm with full precision of the statistics.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ public static final class AttributeCEFullPrecSerializer extends Serializer<AttributeContentEvent> {
+
+ @Override
+ public void write(Kryo kryo, Output output, AttributeContentEvent event) {
+ output.writeLong(event.learningNodeId, true);
+ output.writeInt(event.obsIndex, true);
+ output.writeDouble(event.attrVal);
+ output.writeInt(event.classVal, true);
+ output.writeDouble(event.weight);
+ output.writeBoolean(event.isNominal);
+ }
+
+ @Override
+ public AttributeContentEvent read(Kryo kryo, Input input,
+ Class<AttributeContentEvent> type) {
+ AttributeContentEvent ace = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true))
+ .attrValue(input.readDouble())
+ .classValue(input.readInt(true))
+ .weight(input.readDouble())
+ .isNominal(input.readBoolean())
+ .build();
+ return ace;
+ }
+
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java
index 52f4685..8113a41 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java
@@ -26,117 +26,121 @@
import com.esotericsoftware.kryo.io.Output;
/**
- * Compute content event is the message that is sent by Model Aggregator Processor
- * to request Local Statistic PI to start the local statistic calculation for splitting
+ * Compute content event is the message that is sent by Model Aggregator
+ * Processor to request Local Statistic PI to start the local statistic
+ * calculation for splitting
+ *
* @author Arinto Murdopo
- *
+ *
*/
public final class ComputeContentEvent extends ControlContentEvent {
-
- private static final long serialVersionUID = 5590798490073395190L;
-
- private final double[] preSplitDist;
- private final long splitId;
-
- public ComputeContentEvent(){
- super(-1);
- preSplitDist = null;
- splitId = -1;
- }
- ComputeContentEvent(long splitId, long id, double[] preSplitDist) {
- super(id);
- //this.preSplitDist = Arrays.copyOf(preSplitDist, preSplitDist.length);
- this.preSplitDist = preSplitDist;
- this.splitId = splitId;
- }
+ private static final long serialVersionUID = 5590798490073395190L;
- @Override
- LocStatControl getType() {
- return LocStatControl.COMPUTE;
- }
-
- double[] getPreSplitDist(){
- return this.preSplitDist;
- }
-
- long getSplitId(){
- return this.splitId;
- }
-
- /**
- * The Kryo serializer class for ComputeContentEevent when executing on top of Storm.
- * This class allow us to change the precision of the statistics.
- * @author Arinto Murdopo
- *
- */
- public static final class ComputeCESerializer extends Serializer<ComputeContentEvent>{
+ private final double[] preSplitDist;
+ private final long splitId;
- private static double PRECISION = 1000000.0;
-
- @Override
- public void write(Kryo kryo, Output output, ComputeContentEvent object) {
- output.writeLong(object.splitId, true);
- output.writeLong(object.learningNodeId, true);
-
- output.writeInt(object.preSplitDist.length, true);
- for(int i = 0; i < object.preSplitDist.length; i++){
- output.writeDouble(object.preSplitDist[i], PRECISION, true);
- }
- }
+ public ComputeContentEvent() {
+ super(-1);
+ preSplitDist = null;
+ splitId = -1;
+ }
- @Override
- public ComputeContentEvent read(Kryo kryo, Input input,
- Class<ComputeContentEvent> type) {
- long splitId = input.readLong(true);
- long learningNodeId = input.readLong(true);
-
- int dataLength = input.readInt(true);
- double[] preSplitDist = new double[dataLength];
-
- for(int i = 0; i < dataLength; i++){
- preSplitDist[i] = input.readDouble(PRECISION, true);
- }
-
- return new ComputeContentEvent(splitId, learningNodeId, preSplitDist);
- }
- }
-
- /**
- * The Kryo serializer class for ComputeContentEevent when executing on top of Storm
- * with full precision of the statistics.
- * @author Arinto Murdopo
- *
- */
- public static final class ComputeCEFullPrecSerializer extends Serializer<ComputeContentEvent>{
+ ComputeContentEvent(long splitId, long id, double[] preSplitDist) {
+ super(id);
+ // this.preSplitDist = Arrays.copyOf(preSplitDist, preSplitDist.length);
+ this.preSplitDist = preSplitDist;
+ this.splitId = splitId;
+ }
- @Override
- public void write(Kryo kryo, Output output, ComputeContentEvent object) {
- output.writeLong(object.splitId, true);
- output.writeLong(object.learningNodeId, true);
-
- output.writeInt(object.preSplitDist.length, true);
- for(int i = 0; i < object.preSplitDist.length; i++){
- output.writeDouble(object.preSplitDist[i]);
- }
- }
+ @Override
+ LocStatControl getType() {
+ return LocStatControl.COMPUTE;
+ }
- @Override
- public ComputeContentEvent read(Kryo kryo, Input input,
- Class<ComputeContentEvent> type) {
- long splitId = input.readLong(true);
- long learningNodeId = input.readLong(true);
-
- int dataLength = input.readInt(true);
- double[] preSplitDist = new double[dataLength];
-
- for(int i = 0; i < dataLength; i++){
- preSplitDist[i] = input.readDouble();
- }
-
- return new ComputeContentEvent(splitId, learningNodeId, preSplitDist);
- }
-
- }
+ double[] getPreSplitDist() {
+ return this.preSplitDist;
+ }
+
+ long getSplitId() {
+ return this.splitId;
+ }
+
+ /**
+ * The Kryo serializer class for ComputeContentEvent when executing on top of
+ * Storm. This class allows us to change the precision of the statistics.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ public static final class ComputeCESerializer extends Serializer<ComputeContentEvent> {
+
+ private static double PRECISION = 1000000.0;
+
+ @Override
+ public void write(Kryo kryo, Output output, ComputeContentEvent object) {
+ output.writeLong(object.splitId, true);
+ output.writeLong(object.learningNodeId, true);
+
+ output.writeInt(object.preSplitDist.length, true);
+ for (int i = 0; i < object.preSplitDist.length; i++) {
+ output.writeDouble(object.preSplitDist[i], PRECISION, true);
+ }
+ }
+
+ @Override
+ public ComputeContentEvent read(Kryo kryo, Input input,
+ Class<ComputeContentEvent> type) {
+ long splitId = input.readLong(true);
+ long learningNodeId = input.readLong(true);
+
+ int dataLength = input.readInt(true);
+ double[] preSplitDist = new double[dataLength];
+
+ for (int i = 0; i < dataLength; i++) {
+ preSplitDist[i] = input.readDouble(PRECISION, true);
+ }
+
+ return new ComputeContentEvent(splitId, learningNodeId, preSplitDist);
+ }
+ }
+
+ /**
+ * The Kryo serializer class for ComputeContentEvent when executing on top of
+ * Storm with full precision of the statistics.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ public static final class ComputeCEFullPrecSerializer extends Serializer<ComputeContentEvent> {
+
+ @Override
+ public void write(Kryo kryo, Output output, ComputeContentEvent object) {
+ output.writeLong(object.splitId, true);
+ output.writeLong(object.learningNodeId, true);
+
+ output.writeInt(object.preSplitDist.length, true);
+ for (int i = 0; i < object.preSplitDist.length; i++) {
+ output.writeDouble(object.preSplitDist[i]);
+ }
+ }
+
+ @Override
+ public ComputeContentEvent read(Kryo kryo, Input input,
+ Class<ComputeContentEvent> type) {
+ long splitId = input.readLong(true);
+ long learningNodeId = input.readLong(true);
+
+ int dataLength = input.readInt(true);
+ double[] preSplitDist = new double[dataLength];
+
+ for (int i = 0; i < dataLength; i++) {
+ preSplitDist[i] = input.readDouble();
+ }
+
+ return new ComputeContentEvent(splitId, learningNodeId, preSplitDist);
+ }
+
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java
index 201ef88..e7b8653 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java
@@ -23,49 +23,51 @@
import com.yahoo.labs.samoa.core.ContentEvent;
/**
- * Abstract class to represent ContentEvent to control Local Statistic Processor.
+ * Abstract class to represent ContentEvent to control Local Statistic
+ * Processor.
+ *
* @author Arinto Murdopo
- *
+ *
*/
abstract class ControlContentEvent implements ContentEvent {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 5837375639629708363L;
+ private static final long serialVersionUID = 5837375639629708363L;
- protected final long learningNodeId;
-
- public ControlContentEvent(){
- this.learningNodeId = -1;
- }
-
- ControlContentEvent(long id){
- this.learningNodeId = id;
- }
-
- @Override
- public final String getKey() {
- return null;
- }
-
- @Override
- public void setKey(String str){
- //Do nothing
- }
-
- @Override
- public boolean isLastEvent(){
- return false;
- }
-
- final long getLearningNodeId(){
- return this.learningNodeId;
- }
-
- abstract LocStatControl getType();
-
- static enum LocStatControl {
- COMPUTE, DELETE
- }
+ protected final long learningNodeId;
+
+ public ControlContentEvent() {
+ this.learningNodeId = -1;
+ }
+
+ ControlContentEvent(long id) {
+ this.learningNodeId = id;
+ }
+
+ @Override
+ public final String getKey() {
+ return null;
+ }
+
+ @Override
+ public void setKey(String str) {
+ // Do nothing
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return false;
+ }
+
+ final long getLearningNodeId() {
+ return this.learningNodeId;
+ }
+
+ abstract LocStatControl getType();
+
+ static enum LocStatControl {
+ COMPUTE, DELETE
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java
index c721255..52631bd 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java
@@ -21,25 +21,27 @@
*/
/**
- * Delete Content Event is the content event that is sent by Model Aggregator Processor
- * to delete unnecessary statistic in Local Statistic Processor.
+ * Delete Content Event is the content event that is sent by Model Aggregator
+ * Processor to delete unnecessary statistics in Local Statistic Processor.
+ *
* @author Arinto Murdopo
- *
+ *
*/
final class DeleteContentEvent extends ControlContentEvent {
- private static final long serialVersionUID = -2105250722560863633L;
+ private static final long serialVersionUID = -2105250722560863633L;
- public DeleteContentEvent(){
- super(-1);
- }
-
- DeleteContentEvent(long id) {
- super(id); }
+ public DeleteContentEvent() {
+ super(-1);
+ }
- @Override
- LocStatControl getType() {
- return LocStatControl.DELETE;
- }
+ DeleteContentEvent(long id) {
+ super(id);
+ }
+
+ @Override
+ LocStatControl getType() {
+ return LocStatControl.DELETE;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java
index b6a73c6..d7f2cae 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java
@@ -20,7 +20,6 @@
* #L%
*/
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,155 +36,159 @@
import java.util.List;
/**
- * Filter Processor that stores and filters the instances before
- * sending them to the Model Aggregator Processor.
-
+ * Filter Processor that stores and filters the instances before sending them to
+ * the Model Aggregator Processor.
+ *
* @author Arinto Murdopo
- *
+ *
*/
final class FilterProcessor implements Processor {
- private static final long serialVersionUID = -1685875718300564885L;
- private static final Logger logger = LoggerFactory.getLogger(FilterProcessor.class);
+ private static final long serialVersionUID = -1685875718300564885L;
+ private static final Logger logger = LoggerFactory.getLogger(FilterProcessor.class);
- private int processorId;
-
- private final Instances dataset;
- private InstancesHeader modelContext;
-
- //available streams
- private Stream outputStream;
-
- //private constructor based on Builder pattern
- private FilterProcessor(Builder builder){
- this.dataset = builder.dataset;
- this.batchSize = builder.batchSize;
- this.delay = builder.delay;
- }
-
- private int waitingInstances = 0;
-
- private int delay = 0;
-
- private int batchSize = 200;
-
- private List<InstanceContentEvent> contentEventList = new LinkedList<InstanceContentEvent>();
-
- @Override
- public boolean process(ContentEvent event) {
- //Receive a new instance from source
- if(event instanceof InstanceContentEvent){
- InstanceContentEvent instanceContentEvent = (InstanceContentEvent) event;
- this.contentEventList.add(instanceContentEvent);
- this.waitingInstances++;
- if (this.waitingInstances == this.batchSize || instanceContentEvent.isLastEvent()){
- //Send Instances
- InstancesContentEvent outputEvent = new InstancesContentEvent(instanceContentEvent);
- boolean isLastEvent = false;
- while (!this.contentEventList.isEmpty()){
- InstanceContentEvent ice = this.contentEventList.remove(0);
- Instance inst = ice.getInstance();
- outputEvent.add(inst);
- if (!isLastEvent) {
- isLastEvent = ice.isLastEvent();
- }
- }
- outputEvent.setLast(isLastEvent);
- this.waitingInstances = 0;
- this.outputStream.put(outputEvent);
- if (this.delay > 0) {
- try {
- Thread.sleep(this.delay);
- } catch(InterruptedException ex) {
- Thread.currentThread().interrupt();
- }
- }
- }
- }
- return false;
- }
-
- @Override
- public void onCreate(int id) {
- this.processorId = id;
- this.waitingInstances = 0;
-
- }
+ private int processorId;
- @Override
- public Processor newProcessor(Processor p) {
- FilterProcessor oldProcessor = (FilterProcessor)p;
- FilterProcessor newProcessor =
- new FilterProcessor.Builder(oldProcessor).build();
-
- newProcessor.setOutputStream(oldProcessor.outputStream);
- return newProcessor;
- }
-
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder();
- sb.append(super.toString());
- return sb.toString();
- }
-
- void setOutputStream(Stream outputStream){
- this.outputStream = outputStream;
- }
-
-
- /**
- * Helper method to generate new ResultContentEvent based on an instance and
- * its prediction result.
- * @param prediction The predicted class label from the decision tree model.
- * @param inEvent The associated instance content event
- * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
- */
- private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){
- ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent());
- rce.setClassifierIndex(this.processorId);
- rce.setEvaluationIndex(inEvent.getEvaluationIndex());
- return rce;
- }
-
-
- /**
- * Builder class to replace constructors with many parameters
- * @author Arinto Murdopo
- *
- */
- static class Builder{
-
- //required parameters
- private final Instances dataset;
-
- private int delay = 0;
-
+ private final Instances dataset;
+ private InstancesHeader modelContext;
+
+ // available streams
+ private Stream outputStream;
+
+ // private constructor based on Builder pattern
+ private FilterProcessor(Builder builder) {
+ this.dataset = builder.dataset;
+ this.batchSize = builder.batchSize;
+ this.delay = builder.delay;
+ }
+
+ private int waitingInstances = 0;
+
+ private int delay = 0;
+
+ private int batchSize = 200;
+
+ private List<InstanceContentEvent> contentEventList = new LinkedList<InstanceContentEvent>();
+
+ @Override
+ public boolean process(ContentEvent event) {
+ // Receive a new instance from source
+ if (event instanceof InstanceContentEvent) {
+ InstanceContentEvent instanceContentEvent = (InstanceContentEvent) event;
+ this.contentEventList.add(instanceContentEvent);
+ this.waitingInstances++;
+ if (this.waitingInstances == this.batchSize || instanceContentEvent.isLastEvent()) {
+ // Send Instances
+ InstancesContentEvent outputEvent = new InstancesContentEvent(instanceContentEvent);
+ boolean isLastEvent = false;
+ while (!this.contentEventList.isEmpty()) {
+ InstanceContentEvent ice = this.contentEventList.remove(0);
+ Instance inst = ice.getInstance();
+ outputEvent.add(inst);
+ if (!isLastEvent) {
+ isLastEvent = ice.isLastEvent();
+ }
+ }
+ outputEvent.setLast(isLastEvent);
+ this.waitingInstances = 0;
+ this.outputStream.put(outputEvent);
+ if (this.delay > 0) {
+ try {
+ Thread.sleep(this.delay);
+ } catch (InterruptedException ex) {
+ Thread.currentThread().interrupt();
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+ this.waitingInstances = 0;
+
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ FilterProcessor oldProcessor = (FilterProcessor) p;
+ FilterProcessor newProcessor =
+ new FilterProcessor.Builder(oldProcessor).build();
+
+ newProcessor.setOutputStream(oldProcessor.outputStream);
+ return newProcessor;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(super.toString());
+ return sb.toString();
+ }
+
+ void setOutputStream(Stream outputStream) {
+ this.outputStream = outputStream;
+ }
+
+ /**
+ * Helper method to generate new ResultContentEvent based on an instance and
+ * its prediction result.
+ *
+ * @param prediction
+ * The predicted class label from the decision tree model.
+ * @param inEvent
+ * The associated instance content event
+ * @return ResultContentEvent to be sent into Evaluator PI or other
+ * destination PI.
+ */
+ private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+ ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
+ inEvent.getClassId(), prediction, inEvent.isLastEvent());
+ rce.setClassifierIndex(this.processorId);
+ rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+ return rce;
+ }
+
+ /**
+ * Builder class to replace constructors with many parameters
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ static class Builder {
+
+ // required parameters
+ private final Instances dataset;
+
+ private int delay = 0;
+
private int batchSize = 200;
- Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- Builder(FilterProcessor oldProcessor){
- this.dataset = oldProcessor.dataset;
- this.delay = oldProcessor.delay;
- this.batchSize = oldProcessor.batchSize;
- }
-
- public Builder delay(int delay){
- this.delay = delay;
- return this;
- }
-
- public Builder batchSize(int val){
- this.batchSize = val;
- return this;
- }
-
- FilterProcessor build(){
- return new FilterProcessor(this);
- }
- }
-
+ Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ Builder(FilterProcessor oldProcessor) {
+ this.dataset = oldProcessor.dataset;
+ this.delay = oldProcessor.delay;
+ this.batchSize = oldProcessor.batchSize;
+ }
+
+ public Builder delay(int delay) {
+ this.delay = delay;
+ return this;
+ }
+
+ public Builder batchSize(int val) {
+ this.batchSize = val;
+ return this;
+ }
+
+ FilterProcessor build() {
+ return new FilterProcessor(this);
+ }
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java
index 4123ea5..ba1c602 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java
@@ -21,57 +21,57 @@
*/
/**
- * Class that represents the necessary data structure of the node where an instance
- * is routed/filtered through the decision tree model.
+ * Class that represents the necessary data structure of the node where an
+ * instance is routed/filtered through the decision tree model.
*
* @author Arinto Murdopo
- *
+ *
*/
-final class FoundNode implements java.io.Serializable{
-
- /**
+final class FoundNode implements java.io.Serializable {
+
+ /**
*
*/
- private static final long serialVersionUID = -637695387934143293L;
-
- private final Node node;
- private final SplitNode parent;
- private final int parentBranch;
-
- FoundNode(Node node, SplitNode splitNode, int parentBranch){
- this.node = node;
- this.parent = splitNode;
- this.parentBranch = parentBranch;
- }
-
- /**
- * Method to get the node where an instance is routed/filtered through the decision tree
- * model for testing and training.
- *
- * @return The node where the instance is routed/filtered
- */
- Node getNode(){
- return this.node;
- }
-
- /**
- * Method to get the parent of the node where an instance is routed/filtered through the decision tree
- * model for testing and training
- *
- * @return The parent of the node
- */
- SplitNode getParent(){
- return this.parent;
- }
-
- /**
- * Method to get the index of the node (where an instance is routed/filtered through the decision tree
- * model for testing and training) in its parent.
- *
- * @return The index of the node in its parent node.
- */
- int getParentBranch(){
- return this.parentBranch;
- }
-
+ private static final long serialVersionUID = -637695387934143293L;
+
+ private final Node node;
+ private final SplitNode parent;
+ private final int parentBranch;
+
+ FoundNode(Node node, SplitNode splitNode, int parentBranch) {
+ this.node = node;
+ this.parent = splitNode;
+ this.parentBranch = parentBranch;
+ }
+
+ /**
+ * Method to get the node where an instance is routed/filtered through the
+ * decision tree model for testing and training.
+ *
+ * @return The node where the instance is routed/filtered
+ */
+ Node getNode() {
+ return this.node;
+ }
+
+ /**
+ * Method to get the parent of the node where an instance is routed/filtered
+ * through the decision tree model for testing and training
+ *
+ * @return The parent of the node
+ */
+ SplitNode getParent() {
+ return this.parent;
+ }
+
+ /**
+ * Method to get the index of the node (where an instance is routed/filtered
+ * through the decision tree model for testing and training) in its parent.
+ *
+ * @return The index of the node in its parent node.
+ */
+ int getParentBranch() {
+ return this.parentBranch;
+ }
+
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java
index 82a05de..d979d1b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java
@@ -23,34 +23,33 @@
import com.yahoo.labs.samoa.instances.Instance;
/**
- * Class that represents inactive learning node. Inactive learning node is
- * a node which only keeps track of the observed class distribution. It does
- * not store the statistic for splitting the node.
+ * Class that represents inactive learning node. Inactive learning node is a
+ * node which only keeps track of the observed class distribution. It does not
+ * store the statistic for splitting the node.
*
* @author Arinto Murdopo
- *
+ *
*/
final class InactiveLearningNode extends LearningNode {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -814552382883472302L;
-
-
- InactiveLearningNode(double[] initialClassObservation) {
- super(initialClassObservation);
- }
+ private static final long serialVersionUID = -814552382883472302L;
- @Override
- void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) {
- this.observedClassDistribution.addToValue(
- (int)inst.classValue(), inst.weight());
- }
+ InactiveLearningNode(double[] initialClassObservation) {
+ super(initialClassObservation);
+ }
- @Override
- double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) {
- return this.observedClassDistribution.getArrayCopy();
- }
+ @Override
+ void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) {
+ this.observedClassDistribution.addToValue(
+ (int) inst.classValue(), inst.weight());
+ }
+
+ @Override
+ double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) {
+ return this.observedClassDistribution.getArrayCopy();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java
index 58de671..8014ce7 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java
@@ -24,32 +24,36 @@
/**
* Abstract class that represents a learning node
+ *
* @author Arinto Murdopo
- *
+ *
*/
abstract class LearningNode extends Node {
- private static final long serialVersionUID = 7157319356146764960L;
-
- protected LearningNode(double[] classObservation) {
- super(classObservation);
- }
-
- /**
- * Method to process the instance for learning
- * @param inst The processed instance
- * @param proc The model aggregator processor where this learning node exists
- */
- abstract void learnFromInstance(Instance inst, ModelAggregatorProcessor proc);
-
- @Override
- protected boolean isLeaf(){
- return true;
- }
-
- @Override
- protected FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
- int parentBranch) {
- return new FoundNode(this, parent, parentBranch);
- }
+ private static final long serialVersionUID = 7157319356146764960L;
+
+ protected LearningNode(double[] classObservation) {
+ super(classObservation);
+ }
+
+ /**
+ * Method to process the instance for learning
+ *
+ * @param inst
+ * The processed instance
+ * @param proc
+ * The model aggregator processor where this learning node exists
+ */
+ abstract void learnFromInstance(Instance inst, ModelAggregatorProcessor proc);
+
+ @Override
+ protected boolean isLeaf() {
+ return true;
+ }
+
+ @Override
+ protected FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
+ int parentBranch) {
+ return new FoundNode(this, parent, parentBranch);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java
index 142d28a..10ea055 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java
@@ -24,69 +24,74 @@
import com.yahoo.labs.samoa.core.ContentEvent;
/**
- * Local Result Content Event is the content event that represents local
+ * Local Result Content Event is the content event that represents local
* calculation of statistic in Local Statistic Processor.
*
* @author Arinto Murdopo
- *
+ *
*/
-final class LocalResultContentEvent implements ContentEvent{
-
- private static final long serialVersionUID = -4206620993777418571L;
-
- private final AttributeSplitSuggestion bestSuggestion;
- private final AttributeSplitSuggestion secondBestSuggestion;
- private final long splitId;
-
- public LocalResultContentEvent(){
- bestSuggestion = null;
- secondBestSuggestion = null;
- splitId = -1;
- }
-
- LocalResultContentEvent(long splitId, AttributeSplitSuggestion best, AttributeSplitSuggestion secondBest){
- this.splitId = splitId;
- this.bestSuggestion = best;
- this.secondBestSuggestion = secondBest;
- }
-
- @Override
- public String getKey() {
- return null;
- }
-
- /**
- * Method to return the best attribute split suggestion from this local statistic calculation.
- * @return The best attribute split suggestion.
- */
- AttributeSplitSuggestion getBestSuggestion(){
- return this.bestSuggestion;
- }
-
- /**
- * Method to return the second best attribute split suggestion from this local statistic calculation.
- * @return The second best attribute split suggestion.
- */
- AttributeSplitSuggestion getSecondBestSuggestion(){
- return this.secondBestSuggestion;
- }
-
- /**
- * Method to get the split ID of this local statistic calculation result
- * @return The split id of this local calculation result
- */
- long getSplitId(){
- return this.splitId;
- }
+final class LocalResultContentEvent implements ContentEvent {
- @Override
- public void setKey(String str) {
- //do nothing
-
- }
+ private static final long serialVersionUID = -4206620993777418571L;
- @Override
- public boolean isLastEvent() {
- return false;
- }
+ private final AttributeSplitSuggestion bestSuggestion;
+ private final AttributeSplitSuggestion secondBestSuggestion;
+ private final long splitId;
+
+ public LocalResultContentEvent() {
+ bestSuggestion = null;
+ secondBestSuggestion = null;
+ splitId = -1;
+ }
+
+ LocalResultContentEvent(long splitId, AttributeSplitSuggestion best, AttributeSplitSuggestion secondBest) {
+ this.splitId = splitId;
+ this.bestSuggestion = best;
+ this.secondBestSuggestion = secondBest;
+ }
+
+ @Override
+ public String getKey() {
+ return null;
+ }
+
+ /**
+ * Method to return the best attribute split suggestion from this local
+ * statistic calculation.
+ *
+ * @return The best attribute split suggestion.
+ */
+ AttributeSplitSuggestion getBestSuggestion() {
+ return this.bestSuggestion;
+ }
+
+ /**
+ * Method to return the second best attribute split suggestion from this local
+ * statistic calculation.
+ *
+ * @return The second best attribute split suggestion.
+ */
+ AttributeSplitSuggestion getSecondBestSuggestion() {
+ return this.secondBestSuggestion;
+ }
+
+ /**
+ * Method to get the split ID of this local statistic calculation result
+ *
+ * @return The split id of this local calculation result
+ */
+ long getSplitId() {
+ return this.splitId;
+ }
+
+ @Override
+ public void setKey(String str) {
+ // do nothing
+
+ }
+
+ @Override
+ public boolean isLastEvent() {
+ return false;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java
index 25e5592..a6335e7 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java
@@ -44,203 +44,203 @@
import com.yahoo.labs.samoa.topology.Stream;
/**
- * Local Statistic Processor contains the local statistic of a subset of the attributes.
+ * Local Statistic Processor contains the local statistic of a subset of the
+ * attributes.
+ *
* @author Arinto Murdopo
- *
+ *
*/
final class LocalStatisticsProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -3967695130634517631L;
- private static Logger logger = LoggerFactory.getLogger(LocalStatisticsProcessor.class);
-
- //Collection of AttributeObservers, for each ActiveLearningNode and AttributeId
- private Table<Long, Integer, AttributeClassObserver> localStats;
-
- private Stream computationResultStream;
-
- private final SplitCriterion splitCriterion;
- private final boolean binarySplit;
- private final AttributeClassObserver nominalClassObserver;
- private final AttributeClassObserver numericClassObserver;
-
- //the two observer classes below are also needed to be setup from the Tree
- private LocalStatisticsProcessor(Builder builder){
- this.splitCriterion = builder.splitCriterion;
- this.binarySplit = builder.binarySplit;
- this.nominalClassObserver = builder.nominalClassObserver;
- this.numericClassObserver = builder.numericClassObserver;
+ private static final long serialVersionUID = -3967695130634517631L;
+ private static Logger logger = LoggerFactory.getLogger(LocalStatisticsProcessor.class);
+
+ // Collection of AttributeObservers, for each ActiveLearningNode and
+ // AttributeId
+ private Table<Long, Integer, AttributeClassObserver> localStats;
+
+ private Stream computationResultStream;
+
+ private final SplitCriterion splitCriterion;
+ private final boolean binarySplit;
+ private final AttributeClassObserver nominalClassObserver;
+ private final AttributeClassObserver numericClassObserver;
+
+ // the two observer classes below are also needed to be setup from the Tree
+ private LocalStatisticsProcessor(Builder builder) {
+ this.splitCriterion = builder.splitCriterion;
+ this.binarySplit = builder.binarySplit;
+ this.nominalClassObserver = builder.nominalClassObserver;
+ this.numericClassObserver = builder.numericClassObserver;
+ }
+
+ @Override
+ public boolean process(ContentEvent event) {
+ // process AttributeContentEvent by updating the subset of local statistics
+ if (event instanceof AttributeBatchContentEvent) {
+ AttributeBatchContentEvent abce = (AttributeBatchContentEvent) event;
+ List<ContentEvent> contentEventList = abce.getContentEventList();
+ for (ContentEvent contentEvent : contentEventList) {
+ AttributeContentEvent ace = (AttributeContentEvent) contentEvent;
+ Long learningNodeId = ace.getLearningNodeId();
+ Integer obsIndex = ace.getObsIndex();
+
+ AttributeClassObserver obs = localStats.get(
+ learningNodeId, obsIndex);
+
+ if (obs == null) {
+ obs = ace.isNominal() ? newNominalClassObserver()
+ : newNumericClassObserver();
+ localStats.put(ace.getLearningNodeId(), obsIndex, obs);
+ }
+ obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(),
+ ace.getWeight());
+ }
+
+ /*
+ * if (event instanceof AttributeContentEvent) { AttributeContentEvent ace
+ * = (AttributeContentEvent) event; Long learningNodeId =
+ * Long.valueOf(ace.getLearningNodeId()); Integer obsIndex =
+ * Integer.valueOf(ace.getObsIndex());
+ *
+ * AttributeClassObserver obs = localStats.get( learningNodeId, obsIndex);
+ *
+ * if (obs == null) { obs = ace.isNominal() ? newNominalClassObserver() :
+ * newNumericClassObserver(); localStats.put(ace.getLearningNodeId(),
+ * obsIndex, obs); } obs.observeAttributeClass(ace.getAttrVal(),
+ * ace.getClassVal(), ace.getWeight());
+ */
+ } else if (event instanceof ComputeContentEvent) {
+ // process ComputeContentEvent by calculating the local statistic
+ // and send back the calculation results via computation result stream.
+ ComputeContentEvent cce = (ComputeContentEvent) event;
+ Long learningNodeId = cce.getLearningNodeId();
+ double[] preSplitDist = cce.getPreSplitDist();
+
+ Map<Integer, AttributeClassObserver> learningNodeRowMap = localStats
+ .row(learningNodeId);
+ List<AttributeSplitSuggestion> suggestions = new Vector<>();
+
+ for (Entry<Integer, AttributeClassObserver> entry : learningNodeRowMap.entrySet()) {
+ AttributeClassObserver obs = entry.getValue();
+ AttributeSplitSuggestion suggestion = obs
+ .getBestEvaluatedSplitSuggestion(splitCriterion,
+ preSplitDist, entry.getKey(), binarySplit);
+ if (suggestion != null) {
+ suggestions.add(suggestion);
+ }
+ }
+
+ AttributeSplitSuggestion[] bestSuggestions = suggestions
+ .toArray(new AttributeSplitSuggestion[suggestions.size()]);
+
+ Arrays.sort(bestSuggestions);
+
+ AttributeSplitSuggestion bestSuggestion = null;
+ AttributeSplitSuggestion secondBestSuggestion = null;
+
+ if (bestSuggestions.length >= 1) {
+ bestSuggestion = bestSuggestions[bestSuggestions.length - 1];
+
+ if (bestSuggestions.length >= 2) {
+ secondBestSuggestion = bestSuggestions[bestSuggestions.length - 2];
+ }
+ }
+
+ // create the local result content event
+ LocalResultContentEvent lcre =
+ new LocalResultContentEvent(cce.getSplitId(), bestSuggestion, secondBestSuggestion);
+ computationResultStream.put(lcre);
+ logger.debug("Finish compute event");
+ } else if (event instanceof DeleteContentEvent) {
+ DeleteContentEvent dce = (DeleteContentEvent) event;
+ Long learningNodeId = dce.getLearningNodeId();
+ localStats.rowMap().remove(learningNodeId);
}
-
- @Override
- public boolean process(ContentEvent event) {
- //process AttributeContentEvent by updating the subset of local statistics
- if (event instanceof AttributeBatchContentEvent) {
- AttributeBatchContentEvent abce = (AttributeBatchContentEvent) event;
- List<ContentEvent> contentEventList = abce.getContentEventList();
- for (ContentEvent contentEvent: contentEventList ){
- AttributeContentEvent ace = (AttributeContentEvent) contentEvent;
- Long learningNodeId = ace.getLearningNodeId();
- Integer obsIndex = ace.getObsIndex();
+ return false;
+ }
- AttributeClassObserver obs = localStats.get(
- learningNodeId, obsIndex);
+ @Override
+ public void onCreate(int id) {
+ this.localStats = HashBasedTable.create();
+ }
- if (obs == null) {
- obs = ace.isNominal() ? newNominalClassObserver()
- : newNumericClassObserver();
- localStats.put(ace.getLearningNodeId(), obsIndex, obs);
- }
- obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(),
- ace.getWeight());
- }
-
-
- /*if (event instanceof AttributeContentEvent) {
- AttributeContentEvent ace = (AttributeContentEvent) event;
- Long learningNodeId = Long.valueOf(ace.getLearningNodeId());
- Integer obsIndex = Integer.valueOf(ace.getObsIndex());
+ @Override
+ public Processor newProcessor(Processor p) {
+ LocalStatisticsProcessor oldProcessor = (LocalStatisticsProcessor) p;
+ LocalStatisticsProcessor newProcessor = new LocalStatisticsProcessor.Builder(oldProcessor).build();
- AttributeClassObserver obs = localStats.get(
- learningNodeId, obsIndex);
+ newProcessor.setComputationResultStream(oldProcessor.computationResultStream);
- if (obs == null) {
- obs = ace.isNominal() ? newNominalClassObserver()
- : newNumericClassObserver();
- localStats.put(ace.getLearningNodeId(), obsIndex, obs);
- }
- obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(),
- ace.getWeight());
- */
- } else if (event instanceof ComputeContentEvent) {
- //process ComputeContentEvent by calculating the local statistic
- //and send back the calculation results via computation result stream.
- ComputeContentEvent cce = (ComputeContentEvent) event;
- Long learningNodeId = cce.getLearningNodeId();
- double[] preSplitDist = cce.getPreSplitDist();
-
- Map<Integer, AttributeClassObserver> learningNodeRowMap = localStats
- .row(learningNodeId);
- List<AttributeSplitSuggestion> suggestions = new Vector<>();
+ return newProcessor;
+ }
- for (Entry<Integer, AttributeClassObserver> entry : learningNodeRowMap.entrySet()) {
- AttributeClassObserver obs = entry.getValue();
- AttributeSplitSuggestion suggestion = obs
- .getBestEvaluatedSplitSuggestion(splitCriterion,
- preSplitDist, entry.getKey(), binarySplit);
- if(suggestion != null){
- suggestions.add(suggestion);
- }
- }
-
- AttributeSplitSuggestion[] bestSuggestions = suggestions
- .toArray(new AttributeSplitSuggestion[suggestions.size()]);
+ /**
+ * Method to set the computation result when using this processor to build a
+ * topology.
+ *
+ * @param computeStream
+ */
+ void setComputationResultStream(Stream computeStream) {
+ this.computationResultStream = computeStream;
+ }
- Arrays.sort(bestSuggestions);
-
- AttributeSplitSuggestion bestSuggestion = null;
- AttributeSplitSuggestion secondBestSuggestion = null;
-
- if (bestSuggestions.length >= 1){
- bestSuggestion = bestSuggestions[bestSuggestions.length - 1];
-
- if(bestSuggestions.length >= 2){
- secondBestSuggestion = bestSuggestions[bestSuggestions.length - 2];
- }
- }
-
- //create the local result content event
- LocalResultContentEvent lcre =
- new LocalResultContentEvent(cce.getSplitId(), bestSuggestion, secondBestSuggestion);
- computationResultStream.put(lcre);
- logger.debug("Finish compute event");
- } else if (event instanceof DeleteContentEvent) {
- DeleteContentEvent dce = (DeleteContentEvent) event;
- Long learningNodeId = dce.getLearningNodeId();
- localStats.rowMap().remove(learningNodeId);
- }
- return false;
- }
+ private AttributeClassObserver newNominalClassObserver() {
+ return (AttributeClassObserver) this.nominalClassObserver.copy();
+ }
- @Override
- public void onCreate(int id) {
- this.localStats = HashBasedTable.create();
- }
+ private AttributeClassObserver newNumericClassObserver() {
+ return (AttributeClassObserver) this.numericClassObserver.copy();
+ }
- @Override
- public Processor newProcessor(Processor p) {
- LocalStatisticsProcessor oldProcessor = (LocalStatisticsProcessor) p;
- LocalStatisticsProcessor newProcessor
- = new LocalStatisticsProcessor.Builder(oldProcessor).build();
-
- newProcessor.setComputationResultStream(oldProcessor.computationResultStream);
-
- return newProcessor;
- }
-
- /**
- * Method to set the computation result when using this processor to build
- * a topology.
- * @param computeStream
- */
- void setComputationResultStream(Stream computeStream){
- this.computationResultStream = computeStream;
- }
-
- private AttributeClassObserver newNominalClassObserver() {
- return (AttributeClassObserver)this.nominalClassObserver.copy();
+ /**
+ * Builder class to replace constructors with many parameters
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ static class Builder {
+
+ private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
+ private boolean binarySplit = false;
+ private AttributeClassObserver nominalClassObserver = new NominalAttributeClassObserver();
+ private AttributeClassObserver numericClassObserver = new GaussianNumericAttributeClassObserver();
+
+ Builder() {
+
}
- private AttributeClassObserver newNumericClassObserver() {
- return (AttributeClassObserver)this.numericClassObserver.copy();
+ Builder(LocalStatisticsProcessor oldProcessor) {
+ this.splitCriterion = oldProcessor.splitCriterion;
+ this.binarySplit = oldProcessor.binarySplit;
}
-
- /**
- * Builder class to replace constructors with many parameters
- * @author Arinto Murdopo
- *
- */
- static class Builder{
-
- private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
- private boolean binarySplit = false;
- private AttributeClassObserver nominalClassObserver = new NominalAttributeClassObserver();
- private AttributeClassObserver numericClassObserver = new GaussianNumericAttributeClassObserver();
-
- Builder(){
-
- }
-
- Builder(LocalStatisticsProcessor oldProcessor){
- this.splitCriterion = oldProcessor.splitCriterion;
- this.binarySplit = oldProcessor.binarySplit;
- }
-
- Builder splitCriterion(SplitCriterion splitCriterion){
- this.splitCriterion = splitCriterion;
- return this;
- }
-
- Builder binarySplit(boolean binarySplit){
- this.binarySplit = binarySplit;
- return this;
- }
-
- Builder nominalClassObserver(AttributeClassObserver nominalClassObserver){
- this.nominalClassObserver = nominalClassObserver;
- return this;
- }
-
- Builder numericClassObserver(AttributeClassObserver numericClassObserver){
- this.numericClassObserver = numericClassObserver;
- return this;
- }
-
- LocalStatisticsProcessor build(){
- return new LocalStatisticsProcessor(this);
- }
+
+ Builder splitCriterion(SplitCriterion splitCriterion) {
+ this.splitCriterion = splitCriterion;
+ return this;
}
+ Builder binarySplit(boolean binarySplit) {
+ this.binarySplit = binarySplit;
+ return this;
+ }
+
+ Builder nominalClassObserver(AttributeClassObserver nominalClassObserver) {
+ this.nominalClassObserver = nominalClassObserver;
+ return this;
+ }
+
+ Builder numericClassObserver(AttributeClassObserver numericClassObserver) {
+ this.numericClassObserver = numericClassObserver;
+ return this;
+ }
+
+ LocalStatisticsProcessor build() {
+ return new LocalStatisticsProcessor(this);
+ }
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java
index cf7a1b3..b526a40 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java
@@ -53,679 +53,706 @@
import static com.yahoo.labs.samoa.moa.core.Utils.maxIndex;
/**
- * Model Aggegator Processor consists of the decision tree model. It connects
- * to local-statistic PI via attribute stream and control stream.
- * Model-aggregator PI sends the split instances via attribute stream and
- * it sends control messages to ask local-statistic PI to perform computation
- * via control stream.
+ * Model Aggregator Processor consists of the decision tree model. It connects to
+ * local-statistic PI via attribute stream and control stream. Model-aggregator
+ * PI sends the split instances via attribute stream and it sends control
+ * messages to ask local-statistic PI to perform computation via control stream.
*
- * Model-aggregator PI sends the classification result via result stream to
- * an evaluator PI for classifier or other destination PI. The calculation
- * results from local statistic arrive to the model-aggregator PI via
- * computation-result stream.
-
+ * Model-aggregator PI sends the classification result via result stream to an
+ * evaluator PI for classifier or other destination PI. The calculation results
+ * from local statistic arrive to the model-aggregator PI via computation-result
+ * stream.
+ *
* @author Arinto Murdopo
- *
+ *
*/
final class ModelAggregatorProcessor implements Processor {
- private static final long serialVersionUID = -1685875718300564886L;
- private static final Logger logger = LoggerFactory.getLogger(ModelAggregatorProcessor.class);
+ private static final long serialVersionUID = -1685875718300564886L;
+ private static final Logger logger = LoggerFactory.getLogger(ModelAggregatorProcessor.class);
- private int processorId;
-
- private Node treeRoot;
-
- private int activeLeafNodeCount;
- private int inactiveLeafNodeCount;
- private int decisionNodeCount;
- private boolean growthAllowed;
-
- private final Instances dataset;
+ private int processorId;
- //to support concurrent split
- private long splitId;
- private ConcurrentMap<Long, SplittingNodeInfo> splittingNodes;
- private BlockingQueue<Long> timedOutSplittingNodes;
-
- //available streams
- private Stream resultStream;
- private Stream attributeStream;
- private Stream controlStream;
-
- private transient ScheduledExecutorService executor;
-
- private final SplitCriterion splitCriterion;
- private final double splitConfidence;
- private final double tieThreshold;
- private final int gracePeriod;
- private final int parallelismHint;
- private final long timeOut;
-
- //private constructor based on Builder pattern
- private ModelAggregatorProcessor(Builder builder){
- this.dataset = builder.dataset;
- this.splitCriterion = builder.splitCriterion;
- this.splitConfidence = builder.splitConfidence;
- this.tieThreshold = builder.tieThreshold;
- this.gracePeriod = builder.gracePeriod;
- this.parallelismHint = builder.parallelismHint;
- this.timeOut = builder.timeOut;
- this.changeDetector = builder.changeDetector;
+ private Node treeRoot;
- InstancesHeader ih = new InstancesHeader(dataset);
- this.setModelContext(ih);
- }
+ private int activeLeafNodeCount;
+ private int inactiveLeafNodeCount;
+ private int decisionNodeCount;
+ private boolean growthAllowed;
- @Override
- public boolean process(ContentEvent event) {
-
- //Poll the blocking queue shared between ModelAggregator and the time-out threads
- Long timedOutSplitId = timedOutSplittingNodes.poll();
- if(timedOutSplitId != null){ //time out has been reached!
- SplittingNodeInfo splittingNode = splittingNodes.get(timedOutSplitId);
- if (splittingNode != null) {
- this.splittingNodes.remove(timedOutSplitId);
- this.continueAttemptToSplit(splittingNode.activeLearningNode,
- splittingNode.foundNode);
-
- }
+ private final Instances dataset;
- }
-
- //Receive a new instance from source
- if(event instanceof InstancesContentEvent){
- InstancesContentEvent instancesEvent = (InstancesContentEvent) event;
- this.processInstanceContentEvent(instancesEvent);
- //Send information to local-statistic PI
- //for each of the nodes
- if (this.foundNodeSet != null){
- for (FoundNode foundNode: this.foundNodeSet ){
- ActiveLearningNode leafNode = (ActiveLearningNode) foundNode.getNode();
- AttributeBatchContentEvent[] abce = leafNode.getAttributeBatchContentEvent();
- if (abce != null) {
- for (int i = 0; i< this.dataset.numAttributes() - 1; i++) {
- this.sendToAttributeStream(abce[i]);
- }
- }
- leafNode.setAttributeBatchContentEvent(null);
- //this.sendToControlStream(event); //split information
- //See if we can ask for splits
- if(!leafNode.isSplitting()){
- double weightSeen = leafNode.getWeightSeen();
- //check whether it is the time for splitting
- if(weightSeen - leafNode.getWeightSeenAtLastSplitEvaluation() >= this.gracePeriod){
- attemptToSplit(leafNode, foundNode);
- }
- }
- }
- }
- this.foundNodeSet = null;
- } else if(event instanceof LocalResultContentEvent){
- LocalResultContentEvent lrce = (LocalResultContentEvent) event;
- Long lrceSplitId = lrce.getSplitId();
- SplittingNodeInfo splittingNodeInfo = splittingNodes.get(lrceSplitId);
-
- if (splittingNodeInfo != null) { // if null, that means
- // activeLearningNode has been
- // removed by timeout thread
- ActiveLearningNode activeLearningNode = splittingNodeInfo.activeLearningNode;
+ // to support concurrent split
+ private long splitId;
+ private ConcurrentMap<Long, SplittingNodeInfo> splittingNodes;
+ private BlockingQueue<Long> timedOutSplittingNodes;
- activeLearningNode.addDistributedSuggestions(
- lrce.getBestSuggestion(),
- lrce.getSecondBestSuggestion());
+ // available streams
+ private Stream resultStream;
+ private Stream attributeStream;
+ private Stream controlStream;
- if (activeLearningNode.isAllSuggestionsCollected()) {
- splittingNodeInfo.scheduledFuture.cancel(false);
- this.splittingNodes.remove(lrceSplitId);
- this.continueAttemptToSplit(activeLearningNode,
- splittingNodeInfo.foundNode);
- }
- }
- }
- return false;
- }
+ private transient ScheduledExecutorService executor;
- protected Set<FoundNode> foundNodeSet;
-
- @Override
- public void onCreate(int id) {
- this.processorId = id;
-
- this.activeLeafNodeCount = 0;
- this.inactiveLeafNodeCount = 0;
- this.decisionNodeCount = 0;
- this.growthAllowed = true;
-
- this.splittingNodes = new ConcurrentHashMap<>();
- this.timedOutSplittingNodes = new LinkedBlockingQueue<>();
- this.splitId = 0;
-
- //Executor for scheduling time-out threads
- this.executor = Executors.newScheduledThreadPool(8);
- }
+ private final SplitCriterion splitCriterion;
+ private final double splitConfidence;
+ private final double tieThreshold;
+ private final int gracePeriod;
+ private final int parallelismHint;
+ private final long timeOut;
- @Override
- public Processor newProcessor(Processor p) {
- ModelAggregatorProcessor oldProcessor = (ModelAggregatorProcessor)p;
- ModelAggregatorProcessor newProcessor =
- new ModelAggregatorProcessor.Builder(oldProcessor).build();
-
- newProcessor.setResultStream(oldProcessor.resultStream);
- newProcessor.setAttributeStream(oldProcessor.attributeStream);
- newProcessor.setControlStream(oldProcessor.controlStream);
- return newProcessor;
- }
-
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder();
- sb.append(super.toString());
+ // private constructor based on Builder pattern
+ private ModelAggregatorProcessor(Builder builder) {
+ this.dataset = builder.dataset;
+ this.splitCriterion = builder.splitCriterion;
+ this.splitConfidence = builder.splitConfidence;
+ this.tieThreshold = builder.tieThreshold;
+ this.gracePeriod = builder.gracePeriod;
+ this.parallelismHint = builder.parallelismHint;
+ this.timeOut = builder.timeOut;
+ this.changeDetector = builder.changeDetector;
- sb.append("ActiveLeafNodeCount: ").append(activeLeafNodeCount);
- sb.append("InactiveLeafNodeCount: ").append(inactiveLeafNodeCount);
- sb.append("DecisionNodeCount: ").append(decisionNodeCount);
- sb.append("Growth allowed: ").append(growthAllowed);
- return sb.toString();
- }
-
- void setResultStream(Stream resultStream){
- this.resultStream = resultStream;
- }
-
- void setAttributeStream(Stream attributeStream){
- this.attributeStream = attributeStream;
- }
-
- void setControlStream(Stream controlStream){
- this.controlStream = controlStream;
- }
-
- void sendToAttributeStream(ContentEvent event){
- this.attributeStream.put(event);
- }
-
- void sendToControlStream(ContentEvent event){
- this.controlStream.put(event);
- }
-
- /**
- * Helper method to generate new ResultContentEvent based on an instance and
- * its prediction result.
- * @param prediction The predicted class label from the decision tree model.
- * @param inEvent The associated instance content event
- * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
- */
- private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){
- ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
- inEvent.getClassId(), prediction, inEvent.isLastEvent());
- rce.setClassifierIndex(this.processorId);
- rce.setEvaluationIndex(inEvent.getEvaluationIndex());
- return rce;
- }
-
- private ResultContentEvent newResultContentEvent(double[] prediction, Instance inst, InstancesContentEvent inEvent){
- ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inst, (int) inst.classValue(), prediction, inEvent.isLastEvent());
- rce.setClassifierIndex(this.processorId);
- rce.setEvaluationIndex(inEvent.getEvaluationIndex());
- return rce;
- }
-
- private List<InstancesContentEvent> contentEventList = new LinkedList<>();
+ InstancesHeader ih = new InstancesHeader(dataset);
+ this.setModelContext(ih);
+ }
-
- /**
- * Helper method to process the InstanceContentEvent
- * @param instContentEvent
- */
- private void processInstanceContentEvent(InstancesContentEvent instContentEvent){
- this.numBatches++;
- this.contentEventList.add(instContentEvent);
- if (this.numBatches == 1 || this.numBatches > 4){
- this.processInstances(this.contentEventList.remove(0));
- }
+ @Override
+ public boolean process(ContentEvent event) {
- if (instContentEvent.isLastEvent()) {
- // drain remaining instances
- while (!contentEventList.isEmpty()) {
- processInstances(contentEventList.remove(0));
- }
- }
-
- }
-
- private int numBatches = 0;
-
- private void processInstances(InstancesContentEvent instContentEvent){
-
- Instance[] instances = instContentEvent.getInstances();
- boolean isTesting = instContentEvent.isTesting();
- boolean isTraining= instContentEvent.isTraining();
- for (Instance inst: instances){
- this.processInstance(inst,instContentEvent, isTesting, isTraining);
- }
- }
-
- private void processInstance(Instance inst, InstancesContentEvent instContentEvent, boolean isTesting, boolean isTraining){
- inst.setDataset(this.dataset);
- //Check the instance whether it is used for testing or training
- //boolean testAndTrain = isTraining; //Train after testing
- double[] prediction = null;
- if (isTesting) {
- prediction = getVotesForInstance(inst, false);
- this.resultStream.put(newResultContentEvent(prediction, inst,
- instContentEvent));
- }
+ // Poll the blocking queue shared between ModelAggregator and the time-out
+ // threads
+ Long timedOutSplitId = timedOutSplittingNodes.poll();
+ if (timedOutSplitId != null) { // time out has been reached!
+ SplittingNodeInfo splittingNode = splittingNodes.get(timedOutSplitId);
+ if (splittingNode != null) {
+ this.splittingNodes.remove(timedOutSplitId);
+ this.continueAttemptToSplit(splittingNode.activeLearningNode,
+ splittingNode.foundNode);
- if (isTraining) {
- trainOnInstanceImpl(inst);
- if (this.changeDetector != null) {
- if (prediction == null) {
- prediction = getVotesForInstance(inst);
- }
- boolean correctlyClassifies = this.correctlyClassifies(inst,prediction);
- double oldEstimation = this.changeDetector.getEstimation();
- this.changeDetector.input(correctlyClassifies ? 0 : 1);
- if (this.changeDetector.getEstimation() > oldEstimation) {
- //Start a new classifier
- logger.info("Change detected, resetting the classifier");
- this.resetLearning();
- this.changeDetector.resetLearning();
- }
- }
- }
- }
-
- private boolean correctlyClassifies(Instance inst, double[] prediction) {
- return maxIndex(prediction) == (int) inst.classValue();
- }
-
- private void resetLearning() {
- this.treeRoot = null;
- //Remove nodes
- FoundNode[] learningNodes = findNodes();
- for (FoundNode learningNode : learningNodes) {
- Node node = learningNode.getNode();
- if (node instanceof SplitNode) {
- SplitNode splitNode;
- splitNode = (SplitNode) node;
- for (int i = 0; i < splitNode.numChildren(); i++) {
- splitNode.setChild(i, null);
- }
- }
- }
- }
-
- protected FoundNode[] findNodes() {
- List<FoundNode> foundList = new LinkedList<>();
- findNodes(this.treeRoot, null, -1, foundList);
- return foundList.toArray(new FoundNode[foundList.size()]);
+ }
+
}
- protected void findNodes(Node node, SplitNode parent,
- int parentBranch, List<FoundNode> found) {
- if (node != null) {
- found.add(new FoundNode(node, parent, parentBranch));
- if (node instanceof SplitNode) {
- SplitNode splitNode = (SplitNode) node;
- for (int i = 0; i < splitNode.numChildren(); i++) {
- findNodes(splitNode.getChild(i), splitNode, i,
- found);
- }
+ // Receive a new instance from source
+ if (event instanceof InstancesContentEvent) {
+ InstancesContentEvent instancesEvent = (InstancesContentEvent) event;
+ this.processInstanceContentEvent(instancesEvent);
+ // Send information to local-statistic PI
+ // for each of the nodes
+ if (this.foundNodeSet != null) {
+ for (FoundNode foundNode : this.foundNodeSet) {
+ ActiveLearningNode leafNode = (ActiveLearningNode) foundNode.getNode();
+ AttributeBatchContentEvent[] abce = leafNode.getAttributeBatchContentEvent();
+ if (abce != null) {
+ for (int i = 0; i < this.dataset.numAttributes() - 1; i++) {
+ this.sendToAttributeStream(abce[i]);
}
+ }
+ leafNode.setAttributeBatchContentEvent(null);
+ // this.sendToControlStream(event); //split information
+ // See if we can ask for splits
+ if (!leafNode.isSplitting()) {
+ double weightSeen = leafNode.getWeightSeen();
+ // check whether it is the time for splitting
+ if (weightSeen - leafNode.getWeightSeenAtLastSplitEvaluation() >= this.gracePeriod) {
+ attemptToSplit(leafNode, foundNode);
+ }
+ }
}
+ }
+ this.foundNodeSet = null;
+ } else if (event instanceof LocalResultContentEvent) {
+ LocalResultContentEvent lrce = (LocalResultContentEvent) event;
+ Long lrceSplitId = lrce.getSplitId();
+ SplittingNodeInfo splittingNodeInfo = splittingNodes.get(lrceSplitId);
+
+ if (splittingNodeInfo != null) { // if null, that means
+ // activeLearningNode has been
+ // removed by timeout thread
+ ActiveLearningNode activeLearningNode = splittingNodeInfo.activeLearningNode;
+
+ activeLearningNode.addDistributedSuggestions(
+ lrce.getBestSuggestion(),
+ lrce.getSecondBestSuggestion());
+
+ if (activeLearningNode.isAllSuggestionsCollected()) {
+ splittingNodeInfo.scheduledFuture.cancel(false);
+ this.splittingNodes.remove(lrceSplitId);
+ this.continueAttemptToSplit(activeLearningNode,
+ splittingNodeInfo.foundNode);
+ }
+ }
+ }
+ return false;
+ }
+
+ protected Set<FoundNode> foundNodeSet;
+
+ @Override
+ public void onCreate(int id) {
+ this.processorId = id;
+
+ this.activeLeafNodeCount = 0;
+ this.inactiveLeafNodeCount = 0;
+ this.decisionNodeCount = 0;
+ this.growthAllowed = true;
+
+ this.splittingNodes = new ConcurrentHashMap<>();
+ this.timedOutSplittingNodes = new LinkedBlockingQueue<>();
+ this.splitId = 0;
+
+ // Executor for scheduling time-out threads
+ this.executor = Executors.newScheduledThreadPool(8);
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ ModelAggregatorProcessor oldProcessor = (ModelAggregatorProcessor) p;
+ ModelAggregatorProcessor newProcessor =
+ new ModelAggregatorProcessor.Builder(oldProcessor).build();
+
+ newProcessor.setResultStream(oldProcessor.resultStream);
+ newProcessor.setAttributeStream(oldProcessor.attributeStream);
+ newProcessor.setControlStream(oldProcessor.controlStream);
+ return newProcessor;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(super.toString());
+
+ sb.append("ActiveLeafNodeCount: ").append(activeLeafNodeCount);
+ sb.append("InactiveLeafNodeCount: ").append(inactiveLeafNodeCount);
+ sb.append("DecisionNodeCount: ").append(decisionNodeCount);
+ sb.append("Growth allowed: ").append(growthAllowed);
+ return sb.toString();
+ }
+
+ void setResultStream(Stream resultStream) {
+ this.resultStream = resultStream;
+ }
+
+ void setAttributeStream(Stream attributeStream) {
+ this.attributeStream = attributeStream;
+ }
+
+ void setControlStream(Stream controlStream) {
+ this.controlStream = controlStream;
+ }
+
+ void sendToAttributeStream(ContentEvent event) {
+ this.attributeStream.put(event);
+ }
+
+ void sendToControlStream(ContentEvent event) {
+ this.controlStream.put(event);
+ }
+
+ /**
 + * Helper method to generate a new ResultContentEvent based on an instance and
+ * its prediction result.
+ *
+ * @param prediction
+ * The predicted class label from the decision tree model.
+ * @param inEvent
+ * The associated instance content event
+ * @return ResultContentEvent to be sent into Evaluator PI or other
+ * destination PI.
+ */
+ private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+ ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
+ inEvent.getClassId(), prediction, inEvent.isLastEvent());
+ rce.setClassifierIndex(this.processorId);
+ rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+ return rce;
+ }
+
+ private ResultContentEvent newResultContentEvent(double[] prediction, Instance inst, InstancesContentEvent inEvent) {
+ ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inst, (int) inst.classValue(),
+ prediction, inEvent.isLastEvent());
+ rce.setClassifierIndex(this.processorId);
+ rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+ return rce;
+ }
+
+ private List<InstancesContentEvent> contentEventList = new LinkedList<>();
+
+ /**
+ * Helper method to process the InstanceContentEvent
+ *
+ * @param instContentEvent
+ */
+ private void processInstanceContentEvent(InstancesContentEvent instContentEvent) {
+ this.numBatches++;
+ this.contentEventList.add(instContentEvent);
+ if (this.numBatches == 1 || this.numBatches > 4) {
+ this.processInstances(this.contentEventList.remove(0));
}
-
- /**
- * Helper method to get the prediction result.
- * The actual prediction result is delegated to the leaf node.
- * @param inst
- * @return
- */
- private double[] getVotesForInstance(Instance inst){
- return getVotesForInstance(inst, false);
+ if (instContentEvent.isLastEvent()) {
+ // drain remaining instances
+ while (!contentEventList.isEmpty()) {
+ processInstances(contentEventList.remove(0));
+ }
+ }
+
+ }
+
+ private int numBatches = 0;
+
+ private void processInstances(InstancesContentEvent instContentEvent) {
+
+ Instance[] instances = instContentEvent.getInstances();
+ boolean isTesting = instContentEvent.isTesting();
+ boolean isTraining = instContentEvent.isTraining();
+ for (Instance inst : instances) {
+ this.processInstance(inst, instContentEvent, isTesting, isTraining);
+ }
+ }
+
+ private void processInstance(Instance inst, InstancesContentEvent instContentEvent, boolean isTesting,
+ boolean isTraining) {
+ inst.setDataset(this.dataset);
+ // Check the instance whether it is used for testing or training
+ // boolean testAndTrain = isTraining; //Train after testing
+ double[] prediction = null;
+ if (isTesting) {
+ prediction = getVotesForInstance(inst, false);
+ this.resultStream.put(newResultContentEvent(prediction, inst,
+ instContentEvent));
+ }
+
+ if (isTraining) {
+ trainOnInstanceImpl(inst);
+ if (this.changeDetector != null) {
+ if (prediction == null) {
+ prediction = getVotesForInstance(inst);
}
-
- private double[] getVotesForInstance(Instance inst, boolean isTraining){
- double[] ret;
- FoundNode foundNode = null;
- if(this.treeRoot != null){
- foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
- Node leafNode = foundNode.getNode();
- if(leafNode == null){
- leafNode = foundNode.getParent();
- }
-
- ret = leafNode.getClassVotes(inst, this);
- } else {
- int numClasses = this.dataset.numClasses();
- ret = new double[numClasses];
-
- }
-
- //Training after testing to speed up the process
- if (isTraining){
- if(this.treeRoot == null){
- this.treeRoot = newLearningNode(this.parallelismHint);
+ boolean correctlyClassifies = this.correctlyClassifies(inst, prediction);
+ double oldEstimation = this.changeDetector.getEstimation();
+ this.changeDetector.input(correctlyClassifies ? 0 : 1);
+ if (this.changeDetector.getEstimation() > oldEstimation) {
+ // Start a new classifier
+ logger.info("Change detected, resetting the classifier");
+ this.resetLearning();
+ this.changeDetector.resetLearning();
+ }
+ }
+ }
+ }
+
+ private boolean correctlyClassifies(Instance inst, double[] prediction) {
+ return maxIndex(prediction) == (int) inst.classValue();
+ }
+
+ private void resetLearning() {
+ this.treeRoot = null;
+ // Remove nodes
+ FoundNode[] learningNodes = findNodes();
+ for (FoundNode learningNode : learningNodes) {
+ Node node = learningNode.getNode();
+ if (node instanceof SplitNode) {
+ SplitNode splitNode;
+ splitNode = (SplitNode) node;
+ for (int i = 0; i < splitNode.numChildren(); i++) {
+ splitNode.setChild(i, null);
+ }
+ }
+ }
+ }
+
+ protected FoundNode[] findNodes() {
+ List<FoundNode> foundList = new LinkedList<>();
+ findNodes(this.treeRoot, null, -1, foundList);
+ return foundList.toArray(new FoundNode[foundList.size()]);
+ }
+
+ protected void findNodes(Node node, SplitNode parent,
+ int parentBranch, List<FoundNode> found) {
+ if (node != null) {
+ found.add(new FoundNode(node, parent, parentBranch));
+ if (node instanceof SplitNode) {
+ SplitNode splitNode = (SplitNode) node;
+ for (int i = 0; i < splitNode.numChildren(); i++) {
+ findNodes(splitNode.getChild(i), splitNode, i,
+ found);
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper method to get the prediction result. The actual prediction result is
+ * delegated to the leaf node.
+ *
+ * @param inst
+ * @return
+ */
+ private double[] getVotesForInstance(Instance inst) {
+ return getVotesForInstance(inst, false);
+ }
+
+ private double[] getVotesForInstance(Instance inst, boolean isTraining) {
+ double[] ret;
+ FoundNode foundNode = null;
+ if (this.treeRoot != null) {
+ foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
+ Node leafNode = foundNode.getNode();
+ if (leafNode == null) {
+ leafNode = foundNode.getParent();
+ }
+
+ ret = leafNode.getClassVotes(inst, this);
+ } else {
+ int numClasses = this.dataset.numClasses();
+ ret = new double[numClasses];
+
+ }
+
+ // Training after testing to speed up the process
+ if (isTraining) {
+ if (this.treeRoot == null) {
+ this.treeRoot = newLearningNode(this.parallelismHint);
this.activeLeafNodeCount = 1;
foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
- }
- trainOnInstanceImpl(foundNode, inst);
- }
- return ret;
- }
-
- /**
- * Helper method that represent training of an instance. Since it is decision tree,
- * this method routes the incoming instance into the correct leaf and then update the
- * statistic on the found leaf.
- * @param inst
- */
- private void trainOnInstanceImpl(Instance inst) {
- if(this.treeRoot == null){
- this.treeRoot = newLearningNode(this.parallelismHint);
- this.activeLeafNodeCount = 1;
-
- }
- FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
- trainOnInstanceImpl(foundNode, inst);
+ }
+ trainOnInstanceImpl(foundNode, inst);
+ }
+ return ret;
+ }
+
+ /**
 + * Helper method that represents training of an instance. Since it is a decision
 + * tree, this method routes the incoming instance into the correct leaf and
 + * then updates the statistic on the found leaf.
+ *
+ * @param inst
+ */
+ private void trainOnInstanceImpl(Instance inst) {
+ if (this.treeRoot == null) {
+ this.treeRoot = newLearningNode(this.parallelismHint);
+ this.activeLeafNodeCount = 1;
+
+ }
+ FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1);
+ trainOnInstanceImpl(foundNode, inst);
+ }
+
+ private void trainOnInstanceImpl(FoundNode foundNode, Instance inst) {
+
+ Node leafNode = foundNode.getNode();
+
+ if (leafNode == null) {
+ leafNode = newLearningNode(this.parallelismHint);
+ foundNode.getParent().setChild(foundNode.getParentBranch(), leafNode);
+ activeLeafNodeCount++;
+ }
+
+ if (leafNode instanceof LearningNode) {
+ LearningNode learningNode = (LearningNode) leafNode;
+ learningNode.learnFromInstance(inst, this);
+ }
+ if (this.foundNodeSet == null) {
+ this.foundNodeSet = new HashSet<>();
+ }
+ this.foundNodeSet.add(foundNode);
+ }
+
+ /**
+ * Helper method to represent a split attempt
+ *
+ * @param activeLearningNode
+ * The corresponding active learning node which will be split
+ * @param foundNode
 + * The data structure that represents the filtering of the instance
+ * using the tree model.
+ */
+ private void attemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode) {
+ // Increment the split ID
+ this.splitId++;
+
+ // Schedule time-out thread
+ ScheduledFuture<?> timeOutHandler = this.executor.schedule(new AggregationTimeOutHandler(this.splitId,
+ this.timedOutSplittingNodes),
+ this.timeOut, TimeUnit.SECONDS);
+
+ // Keep track of the splitting node information, so that we can continue the
+ // split
+ // once we receive all local statistic calculation from Local Statistic PI
+ // this.splittingNodes.put(Long.valueOf(this.splitId), new
+ // SplittingNodeInfo(activeLearningNode, foundNode, null));
+ this.splittingNodes.put(this.splitId, new SplittingNodeInfo(activeLearningNode, foundNode, timeOutHandler));
+
+ // Inform Local Statistic PI to perform local statistic calculation
+ activeLearningNode.requestDistributedSuggestions(this.splitId, this);
+ }
+
+ /**
+ * Helper method to continue the attempt to split once all local calculation
+ * results are received.
+ *
+ * @param activeLearningNode
+ * The corresponding active learning node which will be split
+ * @param foundNode
 + * The data structure that represents the filtering of the instance
+ * using the tree model.
+ */
+ private void continueAttemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode) {
+ AttributeSplitSuggestion bestSuggestion = activeLearningNode.getDistributedBestSuggestion();
+ AttributeSplitSuggestion secondBestSuggestion = activeLearningNode.getDistributedSecondBestSuggestion();
+
+ // compare with null split
+ double[] preSplitDist = activeLearningNode.getObservedClassDistribution();
+ AttributeSplitSuggestion nullSplit = new AttributeSplitSuggestion(null,
+ new double[0][], this.splitCriterion.getMeritOfSplit(
+ preSplitDist,
+ new double[][] { preSplitDist }));
+
+ if ((bestSuggestion == null) || (nullSplit.compareTo(bestSuggestion) > 0)) {
+ secondBestSuggestion = bestSuggestion;
+ bestSuggestion = nullSplit;
+ } else {
+ if ((secondBestSuggestion == null) || (nullSplit.compareTo(secondBestSuggestion) > 0)) {
+ secondBestSuggestion = nullSplit;
+ }
+ }
+
+ boolean shouldSplit = false;
+
+ if (secondBestSuggestion == null) {
+ shouldSplit = (bestSuggestion != null);
+ } else {
+ double hoeffdingBound = computeHoeffdingBound(
+ this.splitCriterion.getRangeOfMerit(activeLearningNode.getObservedClassDistribution()),
+ this.splitConfidence,
+ activeLearningNode.getWeightSeen());
+
+ if ((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound)
+ || (hoeffdingBound < tieThreshold)) {
+ shouldSplit = true;
+ }
+ // TODO: add poor attributes removal
+ }
+
+ SplitNode parent = foundNode.getParent();
+ int parentBranch = foundNode.getParentBranch();
+
+ // split if the Hoeffding bound condition is satisfied
+ if (shouldSplit) {
+
+ if (bestSuggestion.splitTest != null) {
+ SplitNode newSplit = new SplitNode(bestSuggestion.splitTest, activeLearningNode.getObservedClassDistribution());
+
+ for (int i = 0; i < bestSuggestion.numSplits(); i++) {
+ Node newChild = newLearningNode(bestSuggestion.resultingClassDistributionFromSplit(i), this.parallelismHint);
+ newSplit.setChild(i, newChild);
}
-
- private void trainOnInstanceImpl(FoundNode foundNode, Instance inst) {
-
- Node leafNode = foundNode.getNode();
-
- if(leafNode == null){
- leafNode = newLearningNode(this.parallelismHint);
- foundNode.getParent().setChild(foundNode.getParentBranch(), leafNode);
- activeLeafNodeCount++;
- }
-
- if(leafNode instanceof LearningNode){
- LearningNode learningNode = (LearningNode) leafNode;
- learningNode.learnFromInstance(inst, this);
- }
- if (this.foundNodeSet == null){
- this.foundNodeSet = new HashSet<>();
- }
- this.foundNodeSet.add(foundNode);
- }
-
- /**
- * Helper method to represent a split attempt
- * @param activeLearningNode The corresponding active learning node which will be split
- * @param foundNode The data structure to represents the filtering of the instance using the
- * tree model.
- */
- private void attemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode){
- //Increment the split ID
- this.splitId++;
-
- //Schedule time-out thread
- ScheduledFuture<?> timeOutHandler = this.executor.schedule(new AggregationTimeOutHandler(this.splitId, this.timedOutSplittingNodes),
- this.timeOut, TimeUnit.SECONDS);
-
- //Keep track of the splitting node information, so that we can continue the split
- //once we receive all local statistic calculation from Local Statistic PI
- //this.splittingNodes.put(Long.valueOf(this.splitId), new SplittingNodeInfo(activeLearningNode, foundNode, null));
- this.splittingNodes.put(this.splitId, new SplittingNodeInfo(activeLearningNode, foundNode, timeOutHandler));
-
- //Inform Local Statistic PI to perform local statistic calculation
- activeLearningNode.requestDistributedSuggestions(this.splitId, this);
- }
-
-
- /**
- * Helper method to continue the attempt to split once all local calculation results are received.
- * @param activeLearningNode The corresponding active learning node which will be split
- * @param foundNode The data structure to represents the filtering of the instance using the
- * tree model.
- */
- private void continueAttemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode){
- AttributeSplitSuggestion bestSuggestion = activeLearningNode.getDistributedBestSuggestion();
- AttributeSplitSuggestion secondBestSuggestion = activeLearningNode.getDistributedSecondBestSuggestion();
-
- //compare with null split
- double[] preSplitDist = activeLearningNode.getObservedClassDistribution();
- AttributeSplitSuggestion nullSplit = new AttributeSplitSuggestion(null,
- new double[0][], this.splitCriterion.getMeritOfSplit(
- preSplitDist,
- new double[][]{preSplitDist}));
-
- if((bestSuggestion == null) || (nullSplit.compareTo(bestSuggestion) > 0)){
- secondBestSuggestion = bestSuggestion;
- bestSuggestion = nullSplit;
- }else{
- if((secondBestSuggestion == null) || (nullSplit.compareTo(secondBestSuggestion) > 0)){
- secondBestSuggestion = nullSplit;
- }
- }
-
- boolean shouldSplit = false;
-
- if(secondBestSuggestion == null){
- shouldSplit = (bestSuggestion != null);
- }else{
- double hoeffdingBound = computeHoeffdingBound(
- this.splitCriterion.getRangeOfMerit(activeLearningNode.getObservedClassDistribution()),
- this.splitConfidence,
- activeLearningNode.getWeightSeen());
-
- if((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound)
- || (hoeffdingBound < tieThreshold)) {
- shouldSplit = true;
- }
- //TODO: add poor attributes removal
- }
- SplitNode parent = foundNode.getParent();
- int parentBranch = foundNode.getParentBranch();
-
- //split if the Hoeffding bound condition is satisfied
- if(shouldSplit){
+ this.activeLeafNodeCount--;
+ this.decisionNodeCount++;
+ this.activeLeafNodeCount += bestSuggestion.numSplits();
- if (bestSuggestion.splitTest != null) {
- SplitNode newSplit = new SplitNode(bestSuggestion.splitTest, activeLearningNode.getObservedClassDistribution());
-
- for(int i = 0; i < bestSuggestion.numSplits(); i++){
- Node newChild = newLearningNode(bestSuggestion.resultingClassDistributionFromSplit(i), this.parallelismHint);
- newSplit.setChild(i, newChild);
- }
-
- this.activeLeafNodeCount--;
- this.decisionNodeCount++;
- this.activeLeafNodeCount += bestSuggestion.numSplits();
-
- if(parent == null){
- this.treeRoot = newSplit;
- }else{
- parent.setChild(parentBranch, newSplit);
- }
- }
- //TODO: add check on the model's memory size
- }
-
- //housekeeping
- activeLearningNode.endSplitting();
- activeLearningNode.setWeightSeenAtLastSplitEvaluation(activeLearningNode.getWeightSeen());
- }
-
- /**
- * Helper method to deactivate learning node
- * @param toDeactivate Active Learning Node that will be deactivated
- * @param parent Parent of the soon-to-be-deactivated Active LearningNode
- * @param parentBranch the branch index of the node in the parent node
- */
- private void deactivateLearningNode(ActiveLearningNode toDeactivate, SplitNode parent, int parentBranch){
- Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution());
- if(parent == null){
- this.treeRoot = newLeaf;
- }else{
- parent.setChild(parentBranch, newLeaf);
- }
-
- this.activeLeafNodeCount--;
- this.inactiveLeafNodeCount++;
- }
-
-
- private LearningNode newLearningNode(int parallelismHint){
- return newLearningNode(new double[0], parallelismHint);
- }
-
- private LearningNode newLearningNode(double[] initialClassObservations, int parallelismHint){
- //for VHT optimization, we need to dynamically instantiate the appropriate ActiveLearningNode
- return new ActiveLearningNode(initialClassObservations, parallelismHint);
- }
-
- /**
- * Helper method to set the model context, i.e. how many attributes they are and what is the class index
- * @param ih
- */
- private void setModelContext(InstancesHeader ih){
- //TODO possibly refactored
- if ((ih != null) && (ih.classIndex() < 0)) {
- throw new IllegalArgumentException(
- "Context for a classifier must include a class to learn");
+ if (parent == null) {
+ this.treeRoot = newSplit;
+ } else {
+ parent.setChild(parentBranch, newSplit);
}
- //TODO: check flag for checking whether training has started or not
-
- //model context is used to describe the model
- logger.trace("Model context: {}", ih.toString());
- }
+ }
+ // TODO: add check on the model's memory size
+ }
- private static double computeHoeffdingBound(double range, double confidence, double n){
- return Math.sqrt((Math.pow(range, 2.0) * Math.log(1.0/confidence)) / (2.0*n));
- }
+ // housekeeping
+ activeLearningNode.endSplitting();
+ activeLearningNode.setWeightSeenAtLastSplitEvaluation(activeLearningNode.getWeightSeen());
+ }
- /**
- * AggregationTimeOutHandler is a class to support time-out feature while waiting for local computation results
- * from the local statistic PIs.
- * @author Arinto Murdopo
- *
- */
- static class AggregationTimeOutHandler implements Runnable{
-
- private static final Logger logger = LoggerFactory.getLogger(AggregationTimeOutHandler.class);
- private final Long splitId;
- private final BlockingQueue<Long> toBeSplittedNodes;
-
- AggregationTimeOutHandler(Long splitId, BlockingQueue<Long> toBeSplittedNodes){
- this.splitId = splitId;
- this.toBeSplittedNodes = toBeSplittedNodes;
- }
+ /**
+ * Helper method to deactivate learning node
+ *
+ * @param toDeactivate
+ * Active Learning Node that will be deactivated
+ * @param parent
+ * Parent of the soon-to-be-deactivated Active LearningNode
+ * @param parentBranch
+ * the branch index of the node in the parent node
+ */
+ private void deactivateLearningNode(ActiveLearningNode toDeactivate, SplitNode parent, int parentBranch) {
+ Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution());
+ if (parent == null) {
+ this.treeRoot = newLeaf;
+ } else {
+ parent.setChild(parentBranch, newLeaf);
+ }
- @Override
- public void run() {
- logger.debug("Time out is reached. AggregationTimeOutHandler is started.");
- try {
- toBeSplittedNodes.put(splitId);
- } catch (InterruptedException e) {
- logger.warn("Interrupted while trying to put the ID into the queue");
- }
- logger.debug("AggregationTimeOutHandler is finished.");
- }
- }
-
- /**
- * SplittingNodeInfo is a class to represents the ActiveLearningNode that is splitting
- * @author Arinto Murdopo
- *
- */
- static class SplittingNodeInfo{
-
- private final ActiveLearningNode activeLearningNode;
- private final FoundNode foundNode;
- private final ScheduledFuture<?> scheduledFuture;
-
- SplittingNodeInfo(ActiveLearningNode activeLearningNode, FoundNode foundNode, ScheduledFuture<?> scheduledFuture){
- this.activeLearningNode = activeLearningNode;
- this.foundNode = foundNode;
- this.scheduledFuture = scheduledFuture;
- }
- }
-
- protected ChangeDetector changeDetector;
+ this.activeLeafNodeCount--;
+ this.inactiveLeafNodeCount++;
+ }
- public ChangeDetector getChangeDetector() {
- return this.changeDetector;
- }
+ private LearningNode newLearningNode(int parallelismHint) {
+ return newLearningNode(new double[0], parallelismHint);
+ }
- public void setChangeDetector(ChangeDetector cd) {
- this.changeDetector = cd;
- }
-
- /**
- * Builder class to replace constructors with many parameters
- * @author Arinto Murdopo
- *
- */
- static class Builder{
-
- //required parameters
- private final Instances dataset;
-
- //default values
- private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
- private double splitConfidence = 0.0000001;
- private double tieThreshold = 0.05;
- private int gracePeriod = 200;
- private int parallelismHint = 1;
- private long timeOut = 30;
- private ChangeDetector changeDetector = null;
+ private LearningNode newLearningNode(double[] initialClassObservations, int parallelismHint) {
+ // for VHT optimization, we need to dynamically instantiate the appropriate
+ // ActiveLearningNode
+ return new ActiveLearningNode(initialClassObservations, parallelismHint);
+ }
- Builder(Instances dataset){
- this.dataset = dataset;
- }
-
- Builder(ModelAggregatorProcessor oldProcessor){
- this.dataset = oldProcessor.dataset;
- this.splitCriterion = oldProcessor.splitCriterion;
- this.splitConfidence = oldProcessor.splitConfidence;
- this.tieThreshold = oldProcessor.tieThreshold;
- this.gracePeriod = oldProcessor.gracePeriod;
- this.parallelismHint = oldProcessor.parallelismHint;
- this.timeOut = oldProcessor.timeOut;
- }
-
- Builder splitCriterion(SplitCriterion splitCriterion){
- this.splitCriterion = splitCriterion;
- return this;
- }
-
- Builder splitConfidence(double splitConfidence){
- this.splitConfidence = splitConfidence;
- return this;
- }
-
- Builder tieThreshold(double tieThreshold){
- this.tieThreshold = tieThreshold;
- return this;
- }
-
- Builder gracePeriod(int gracePeriod){
- this.gracePeriod = gracePeriod;
- return this;
- }
-
- Builder parallelismHint(int parallelismHint){
- this.parallelismHint = parallelismHint;
- return this;
- }
-
- Builder timeOut(long timeOut){
- this.timeOut = timeOut;
- return this;
- }
-
- Builder changeDetector(ChangeDetector changeDetector){
- this.changeDetector = changeDetector;
- return this;
- }
- ModelAggregatorProcessor build(){
- return new ModelAggregatorProcessor(this);
- }
- }
-
+ /**
 + * Helper method to set the model context, i.e. how many attributes there are
+ * and what is the class index
+ *
+ * @param ih
+ */
+ private void setModelContext(InstancesHeader ih) {
+ // TODO possibly refactored
+ if ((ih != null) && (ih.classIndex() < 0)) {
+ throw new IllegalArgumentException(
+ "Context for a classifier must include a class to learn");
+ }
+ // TODO: check flag for checking whether training has started or not
+
+ // model context is used to describe the model
+ logger.trace("Model context: {}", ih.toString());
+ }
+
+ private static double computeHoeffdingBound(double range, double confidence, double n) {
+ return Math.sqrt((Math.pow(range, 2.0) * Math.log(1.0 / confidence)) / (2.0 * n));
+ }
+
+ /**
 + * AggregationTimeOutHandler is a class to support the time-out feature while
+ * waiting for local computation results from the local statistic PIs.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ static class AggregationTimeOutHandler implements Runnable {
+
+ private static final Logger logger = LoggerFactory.getLogger(AggregationTimeOutHandler.class);
+ private final Long splitId;
+ private final BlockingQueue<Long> toBeSplittedNodes;
+
+ AggregationTimeOutHandler(Long splitId, BlockingQueue<Long> toBeSplittedNodes) {
+ this.splitId = splitId;
+ this.toBeSplittedNodes = toBeSplittedNodes;
+ }
+
+ @Override
+ public void run() {
+ logger.debug("Time out is reached. AggregationTimeOutHandler is started.");
+ try {
+ toBeSplittedNodes.put(splitId);
+ } catch (InterruptedException e) {
+ logger.warn("Interrupted while trying to put the ID into the queue");
+ }
+ logger.debug("AggregationTimeOutHandler is finished.");
+ }
+ }
+
+ /**
 + * SplittingNodeInfo is a class that represents the ActiveLearningNode that is
+ * splitting
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ static class SplittingNodeInfo {
+
+ private final ActiveLearningNode activeLearningNode;
+ private final FoundNode foundNode;
+ private final ScheduledFuture<?> scheduledFuture;
+
+ SplittingNodeInfo(ActiveLearningNode activeLearningNode, FoundNode foundNode, ScheduledFuture<?> scheduledFuture) {
+ this.activeLearningNode = activeLearningNode;
+ this.foundNode = foundNode;
+ this.scheduledFuture = scheduledFuture;
+ }
+ }
+
+ protected ChangeDetector changeDetector;
+
+ public ChangeDetector getChangeDetector() {
+ return this.changeDetector;
+ }
+
+ public void setChangeDetector(ChangeDetector cd) {
+ this.changeDetector = cd;
+ }
+
+ /**
+ * Builder class to replace constructors with many parameters
+ *
+ * @author Arinto Murdopo
+ *
+ */
+ static class Builder {
+
+ // required parameters
+ private final Instances dataset;
+
+ // default values
+ private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
+ private double splitConfidence = 0.0000001;
+ private double tieThreshold = 0.05;
+ private int gracePeriod = 200;
+ private int parallelismHint = 1;
+ private long timeOut = 30;
+ private ChangeDetector changeDetector = null;
+
+ Builder(Instances dataset) {
+ this.dataset = dataset;
+ }
+
+ Builder(ModelAggregatorProcessor oldProcessor) {
+ this.dataset = oldProcessor.dataset;
+ this.splitCriterion = oldProcessor.splitCriterion;
+ this.splitConfidence = oldProcessor.splitConfidence;
+ this.tieThreshold = oldProcessor.tieThreshold;
+ this.gracePeriod = oldProcessor.gracePeriod;
+ this.parallelismHint = oldProcessor.parallelismHint;
+ this.timeOut = oldProcessor.timeOut;
+ }
+
+ Builder splitCriterion(SplitCriterion splitCriterion) {
+ this.splitCriterion = splitCriterion;
+ return this;
+ }
+
+ Builder splitConfidence(double splitConfidence) {
+ this.splitConfidence = splitConfidence;
+ return this;
+ }
+
+ Builder tieThreshold(double tieThreshold) {
+ this.tieThreshold = tieThreshold;
+ return this;
+ }
+
+ Builder gracePeriod(int gracePeriod) {
+ this.gracePeriod = gracePeriod;
+ return this;
+ }
+
+ Builder parallelismHint(int parallelismHint) {
+ this.parallelismHint = parallelismHint;
+ return this;
+ }
+
+ Builder timeOut(long timeOut) {
+ this.timeOut = timeOut;
+ return this;
+ }
+
+ Builder changeDetector(ChangeDetector changeDetector) {
+ this.changeDetector = changeDetector;
+ return this;
+ }
+
+ ModelAggregatorProcessor build() {
+ return new ModelAggregatorProcessor(this);
+ }
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/Node.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/Node.java
index ff9bf5f..22e551f 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/Node.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/Node.java
@@ -25,70 +25,83 @@
/**
* Abstract class that represents a node in the tree model.
+ *
* @author Arinto Murdopo
- *
+ *
*/
-abstract class Node implements java.io.Serializable{
-
- private static final long serialVersionUID = 4008521239214180548L;
-
- protected final DoubleVector observedClassDistribution;
-
- /**
- * Method to route/filter an instance into its corresponding leaf. This method will be
- * invoked recursively.
- * @param inst Instance to be routed
- * @param parent Parent of the current node
- * @param parentBranch The index of the current node in the parent
- * @return FoundNode which is the data structure to represent the resulting leaf.
- */
- abstract FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch);
-
- /**
- * Method to return the predicted class of the instance based on the statistic
- * inside the node.
- *
- * @param inst To-be-predicted instance
- * @param map ModelAggregatorProcessor
- * @return The prediction result in the form of class distribution
- */
- abstract double[] getClassVotes(Instance inst, ModelAggregatorProcessor map);
-
- /**
- * Method to check whether the node is a leaf node or not.
- * @return Boolean flag to indicate whether the node is a leaf or not
- */
- abstract boolean isLeaf();
-
-
- /**
- * Constructor of the tree node
- * @param classObservation distribution of the observed classes.
- */
- protected Node(double[] classObservation){
- this.observedClassDistribution = new DoubleVector(classObservation);
- }
-
- /**
- * Getter method for the class distribution
- * @return Observed class distribution
- */
- protected double[] getObservedClassDistribution() {
- return this.observedClassDistribution.getArrayCopy();
- }
-
- /**
- * A method to check whether the class distribution only consists of one class or not.
- * @return Flag whether class distribution is pure or not.
- */
- protected boolean observedClassDistributionIsPure(){
- return (observedClassDistribution.numNonZeroEntries() < 2);
- }
-
- protected void describeSubtree(ModelAggregatorProcessor modelAggrProc, StringBuilder out, int indent){
- //TODO: implement method to gracefully define the tree
- }
-
- //TODO: calculate promise for limiting the model based on the memory size
- //double calculatePromise();
+abstract class Node implements java.io.Serializable {
+
+ private static final long serialVersionUID = 4008521239214180548L;
+
+ protected final DoubleVector observedClassDistribution;
+
+ /**
+ * Method to route/filter an instance into its corresponding leaf. This method
+ * will be invoked recursively.
+ *
+ * @param inst
+ * Instance to be routed
+ * @param parent
+ * Parent of the current node
+ * @param parentBranch
+ * The index of the current node in the parent
+ * @return FoundNode which is the data structure to represent the resulting
+ * leaf.
+ */
+ abstract FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch);
+
+ /**
+ * Method to return the predicted class of the instance based on the statistic
+ * inside the node.
+ *
+ * @param inst
+ * To-be-predicted instance
+ * @param map
+ * ModelAggregatorProcessor
+ * @return The prediction result in the form of class distribution
+ */
+ abstract double[] getClassVotes(Instance inst, ModelAggregatorProcessor map);
+
+ /**
+ * Method to check whether the node is a leaf node or not.
+ *
+ * @return Boolean flag to indicate whether the node is a leaf or not
+ */
+ abstract boolean isLeaf();
+
+ /**
+ * Constructor of the tree node
+ *
+ * @param classObservation
+ * distribution of the observed classes.
+ */
+ protected Node(double[] classObservation) {
+ this.observedClassDistribution = new DoubleVector(classObservation);
+ }
+
+ /**
+ * Getter method for the class distribution
+ *
+ * @return Observed class distribution
+ */
+ protected double[] getObservedClassDistribution() {
+ return this.observedClassDistribution.getArrayCopy();
+ }
+
+ /**
+ * A method to check whether the class distribution only consists of one class
+ * or not.
+ *
+ * @return Flag whether class distribution is pure or not.
+ */
+ protected boolean observedClassDistributionIsPure() {
+ return (observedClassDistribution.numNonZeroEntries() < 2);
+ }
+
+ protected void describeSubtree(ModelAggregatorProcessor modelAggrProc, StringBuilder out, int indent) {
+ // TODO: implement method to gracefully define the tree
+ }
+
+ // TODO: calculate promise for limiting the model based on the memory size
+ // double calculatePromise();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/SplitNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/SplitNode.java
index fd93db1..7c6b434 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/SplitNode.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/SplitNode.java
@@ -25,84 +25,94 @@
import com.yahoo.labs.samoa.instances.Instance;
/**
- * SplitNode represents the node that contains one or more questions in the decision tree model,
- * in order to route the instances into the correct leaf.
+ * SplitNode represents the node that contains one or more questions in the
+ * decision tree model, in order to route the instances into the correct leaf.
+ *
* @author Arinto Murdopo
- *
+ *
*/
public class SplitNode extends Node {
-
- private static final long serialVersionUID = -7380795529928485792L;
-
- private final AutoExpandVector<Node> children;
- protected final InstanceConditionalTest splitTest;
-
- public SplitNode(InstanceConditionalTest splitTest,
- double[] classObservation) {
- super(classObservation);
- this.children = new AutoExpandVector<>();
- this.splitTest = splitTest;
- }
- @Override
- FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch) {
- int childIndex = instanceChildIndex(inst);
- if(childIndex >= 0){
- Node child = getChild(childIndex);
- if(child != null){
- return child.filterInstanceToLeaf(inst, this, childIndex);
- }
- return new FoundNode(null, this, childIndex);
- }
- return new FoundNode(this, parent, parentBranch);
- }
+ private static final long serialVersionUID = -7380795529928485792L;
- @Override
- boolean isLeaf() {
- return false;
- }
-
- @Override
- double[] getClassVotes(Instance inst, ModelAggregatorProcessor vht) {
- return this.observedClassDistribution.getArrayCopy();
- }
-
- /**
- * Method to return the number of children of this split node
- * @return number of children
- */
- int numChildren(){
- return this.children.size();
- }
-
- /**
- * Method to set the children in a specific index of the SplitNode with the appropriate child
- * @param index Index of the child in the SplitNode
- * @param child The child node
- */
- void setChild(int index, Node child){
- if ((this.splitTest.maxBranches() >= 0)
- && (index >= this.splitTest.maxBranches())) {
- throw new IndexOutOfBoundsException();
- }
- this.children.set(index, child);
- }
-
- /**
- * Method to get the child node given the index
- * @param index The child node index
- * @return The child node in the given index
- */
- Node getChild(int index){
- return this.children.get(index);
- }
-
- /**
- * Method to route the instance using this split node
- * @param inst The routed instance
- * @return The index of the branch where the instance is routed
- */
- int instanceChildIndex(Instance inst){
- return this.splitTest.branchForInstance(inst);
- }
+ private final AutoExpandVector<Node> children;
+ protected final InstanceConditionalTest splitTest;
+
+ public SplitNode(InstanceConditionalTest splitTest,
+ double[] classObservation) {
+ super(classObservation);
+ this.children = new AutoExpandVector<>();
+ this.splitTest = splitTest;
+ }
+
+ @Override
+ FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch) {
+ int childIndex = instanceChildIndex(inst);
+ if (childIndex >= 0) {
+ Node child = getChild(childIndex);
+ if (child != null) {
+ return child.filterInstanceToLeaf(inst, this, childIndex);
+ }
+ return new FoundNode(null, this, childIndex);
+ }
+ return new FoundNode(this, parent, parentBranch);
+ }
+
+ @Override
+ boolean isLeaf() {
+ return false;
+ }
+
+ @Override
+ double[] getClassVotes(Instance inst, ModelAggregatorProcessor vht) {
+ return this.observedClassDistribution.getArrayCopy();
+ }
+
+ /**
+ * Method to return the number of children of this split node
+ *
+ * @return number of children
+ */
+ int numChildren() {
+ return this.children.size();
+ }
+
+ /**
+ * Method to set the children in a specific index of the SplitNode with the
+ * appropriate child
+ *
+ * @param index
+ * Index of the child in the SplitNode
+ * @param child
+ * The child node
+ */
+ void setChild(int index, Node child) {
+ if ((this.splitTest.maxBranches() >= 0)
+ && (index >= this.splitTest.maxBranches())) {
+ throw new IndexOutOfBoundsException();
+ }
+ this.children.set(index, child);
+ }
+
+ /**
+ * Method to get the child node given the index
+ *
+ * @param index
+ * The child node index
+ * @return The child node in the given index
+ */
+ Node getChild(int index) {
+ return this.children.get(index);
+ }
+
+ /**
+ * Method to route the instance using this split node
+ *
+ * @param inst
+ * The routed instance
+ * @return The index of the branch where the instance is routed
+ */
+ int instanceChildIndex(Instance inst) {
+ return this.splitTest.branchForInstance(inst);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
index e8ccce7..990a15b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java
@@ -46,7 +46,7 @@
* Vertical Hoeffding Tree (VHT) classifier is a distributed classifier that
* utilizes vertical parallelism on top of Very Fast Decision Tree (VFDT)
* classifier.
- *
+ *
* @author Arinto Murdopo
*/
public final class VerticalHoeffdingTree implements ClassificationLearner, AdaptiveLearner, Configurable {
@@ -110,7 +110,6 @@
Stream filterStream = topologyBuilder.createStream(filterProc);
this.filterProc.setOutputStream(filterStream);
-
ModelAggregatorProcessor modelAggrProc = new ModelAggregatorProcessor.Builder(dataset)
.splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
.splitConfidence(splitConfidenceOption.getValue())
@@ -175,7 +174,7 @@
static class LearningNodeIdGenerator {
- //TODO: add code to warn user of when value reaches Long.MAX_VALUES
+ // TODO: add code to warn user of when value reaches Long.MAX_VALUES
private static long id = 0;
static synchronized long generate() {
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java
index a0d950b..4b7b1f6 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java
@@ -30,60 +30,61 @@
@Immutable
final public class ClusteringContentEvent implements ContentEvent {
- private static final long serialVersionUID = -7746983521296618922L;
- private Instance instance;
- private boolean isLast = false;
- private String key;
- private boolean isSample;
+ private static final long serialVersionUID = -7746983521296618922L;
+ private Instance instance;
+ private boolean isLast = false;
+ private String key;
+ private boolean isSample;
- public ClusteringContentEvent() {
- // Necessary for kryo serializer
- }
+ public ClusteringContentEvent() {
+ // Necessary for kryo serializer
+ }
- /**
- * Instantiates a new clustering event.
- *
- * @param index
- * the index
- * @param instance
- * the instance
+ /**
+ * Instantiates a new clustering event.
+ *
+ * @param index
+ * the index
+ * @param instance
+ * the instance
+ */
+ public ClusteringContentEvent(long index, Instance instance) {
+ /*
+ * if (instance != null) { this.instance = new
+ * SerializableInstance(instance); }
*/
- public ClusteringContentEvent(long index, Instance instance) {
- /*
- * if (instance != null) { this.instance = new SerializableInstance(instance); }
- */
- this.instance = instance;
- this.setKey(Long.toString(index));
- }
+ this.instance = instance;
+ this.setKey(Long.toString(index));
+ }
- @Override
- public String getKey() {
- return this.key;
- }
+ @Override
+ public String getKey() {
+ return this.key;
+ }
- @Override
- public void setKey(String str) {
- this.key = str;
- }
+ @Override
+ public void setKey(String str) {
+ this.key = str;
+ }
- @Override
- public boolean isLastEvent() {
- return this.isLast;
- }
+ @Override
+ public boolean isLastEvent() {
+ return this.isLast;
+ }
- public void setLast(boolean isLast) {
- this.isLast = isLast;
- }
+ public void setLast(boolean isLast) {
+ this.isLast = isLast;
+ }
- public Instance getInstance() {
- return this.instance;
- }
+ public Instance getInstance() {
+ return this.instance;
+ }
- public boolean isSample() {
- return isSample;
- }
+ public boolean isSample() {
+ return isSample;
+ }
- public void setSample(boolean b) {
- this.isSample = b;
- }
+ public void setSample(boolean b) {
+ this.isSample = b;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java
index 057e37b..829d448 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java
@@ -31,136 +31,142 @@
import com.yahoo.labs.samoa.moa.clusterers.clustream.Clustream;
/**
- *
- * Base class for adapting Clustream clusterer.
- *
+ *
+ * Base class for adapting Clustream clusterer.
+ *
*/
public class ClustreamClustererAdapter implements LocalClustererAdapter, Configurable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 4372366401338704353L;
-
- public ClassOption learnerOption = new ClassOption("learner", 'l',
- "Clusterer to train.", com.yahoo.labs.samoa.moa.clusterers.Clusterer.class, Clustream.class.getName());
- /**
- * The learner.
- */
- protected com.yahoo.labs.samoa.moa.clusterers.Clusterer learner;
-
- /**
- * The is init.
- */
- protected Boolean isInit;
-
- /**
- * The dataset.
- */
- protected Instances dataset;
+ private static final long serialVersionUID = 4372366401338704353L;
- @Override
- public void setDataset(Instances dataset) {
- this.dataset = dataset;
- }
+ public ClassOption learnerOption = new ClassOption("learner", 'l',
+ "Clusterer to train.", com.yahoo.labs.samoa.moa.clusterers.Clusterer.class, Clustream.class.getName());
+ /**
+ * The learner.
+ */
+ protected com.yahoo.labs.samoa.moa.clusterers.Clusterer learner;
- /**
- * Instantiates a new learner.
- *
- * @param learner the learner
- * @param dataset the dataset
- */
- public ClustreamClustererAdapter(com.yahoo.labs.samoa.moa.clusterers.Clusterer learner, Instances dataset) {
- this.learner = learner.copy();
- this.isInit = false;
- this.dataset = dataset;
- }
+ /**
+ * The is init.
+ */
+ protected Boolean isInit;
- /**
- * Instantiates a new learner.
- *
- * @param learner the learner
- * @param dataset the dataset
- */
- public ClustreamClustererAdapter() {
- this.learner = ((com.yahoo.labs.samoa.moa.clusterers.Clusterer) this.learnerOption.getValue()).copy();
- this.isInit = false;
- //this.dataset = dataset;
- }
+ /**
+ * The dataset.
+ */
+ protected Instances dataset;
- /**
- * Creates a new learner object.
- *
- * @return the learner
- */
- @Override
- public ClustreamClustererAdapter create() {
- ClustreamClustererAdapter l = new ClustreamClustererAdapter(learner, dataset);
- if (dataset == null) {
- System.out.println("dataset null while creating");
- }
- return l;
- }
+ @Override
+ public void setDataset(Instances dataset) {
+ this.dataset = dataset;
+ }
- /**
- * Trains this classifier incrementally using the given instance.
- *
- * @param inst the instance to be used for training
- */
- @Override
- public void trainOnInstance(Instance inst) {
- if (this.isInit == false) {
- this.isInit = true;
- InstancesHeader instances = new InstancesHeader(dataset);
- this.learner.setModelContext(instances);
- this.learner.prepareForUse();
- }
- if (inst.weight() > 0) {
- inst.setDataset(dataset);
- learner.trainOnInstance(inst);
- }
- }
+ /**
+ * Instantiates a new learner.
+ *
+ * @param learner
+ * the learner
+ * @param dataset
+ * the dataset
+ */
+ public ClustreamClustererAdapter(com.yahoo.labs.samoa.moa.clusterers.Clusterer learner, Instances dataset) {
+ this.learner = learner.copy();
+ this.isInit = false;
+ this.dataset = dataset;
+ }
- /**
- * Predicts the class memberships for a given instance. If an instance is
- * unclassified, the returned array elements must be all zero.
- *
- * @param inst the instance to be classified
- * @return an array containing the estimated membership probabilities of the
- * test instance in each class
- */
- @Override
- public double[] getVotesForInstance(Instance inst) {
- double[] ret;
- inst.setDataset(dataset);
- if (this.isInit == false) {
- ret = new double[dataset.numClasses()];
- } else {
- ret = learner.getVotesForInstance(inst);
- }
- return ret;
- }
+ /**
+ * Instantiates a new learner.
+ *
+ * @param learner
+ * the learner
+ * @param dataset
+ * the dataset
+ */
+ public ClustreamClustererAdapter() {
+ this.learner = ((com.yahoo.labs.samoa.moa.clusterers.Clusterer) this.learnerOption.getValue()).copy();
+ this.isInit = false;
+ // this.dataset = dataset;
+ }
- /**
- * Resets this classifier. It must be similar to starting a new classifier
- * from scratch.
- *
- */
- @Override
- public void resetLearning() {
- learner.resetLearning();
+ /**
+ * Creates a new learner object.
+ *
+ * @return the learner
+ */
+ @Override
+ public ClustreamClustererAdapter create() {
+ ClustreamClustererAdapter l = new ClustreamClustererAdapter(learner, dataset);
+ if (dataset == null) {
+ System.out.println("dataset null while creating");
}
+ return l;
+ }
- public boolean implementsMicroClusterer() {
- return this.learner.implementsMicroClusterer();
+ /**
+ * Trains this classifier incrementally using the given instance.
+ *
+ * @param inst
+ * the instance to be used for training
+ */
+ @Override
+ public void trainOnInstance(Instance inst) {
+ if (this.isInit == false) {
+ this.isInit = true;
+ InstancesHeader instances = new InstancesHeader(dataset);
+ this.learner.setModelContext(instances);
+ this.learner.prepareForUse();
}
+ if (inst.weight() > 0) {
+ inst.setDataset(dataset);
+ learner.trainOnInstance(inst);
+ }
+ }
- public Clustering getMicroClusteringResult() {
- return this.learner.getMicroClusteringResult();
+ /**
+ * Predicts the class memberships for a given instance. If an instance is
+ * unclassified, the returned array elements must be all zero.
+ *
+ * @param inst
+ * the instance to be classified
+ * @return an array containing the estimated membership probabilities of the
+ * test instance in each class
+ */
+ @Override
+ public double[] getVotesForInstance(Instance inst) {
+ double[] ret;
+ inst.setDataset(dataset);
+ if (this.isInit == false) {
+ ret = new double[dataset.numClasses()];
+ } else {
+ ret = learner.getVotesForInstance(inst);
}
+ return ret;
+ }
- public Instances getDataset() {
- return this.dataset;
- }
+ /**
+ * Resets this classifier. It must be similar to starting a new classifier
+ * from scratch.
+ *
+ */
+ @Override
+ public void resetLearning() {
+ learner.resetLearning();
+ }
+
+ public boolean implementsMicroClusterer() {
+ return this.learner.implementsMicroClusterer();
+ }
+
+ public Clustering getMicroClusteringResult() {
+ return this.learner.getMicroClusteringResult();
+ }
+
+ public Instances getDataset() {
+ return this.dataset;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java
index fedbcfe..4e4cd6e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java
@@ -31,52 +31,52 @@
* @author abifet
*/
public interface LocalClustererAdapter extends Serializable {
-
- /**
- * Creates a new learner object.
- *
- * @return the learner
- */
- LocalClustererAdapter create();
- /**
- * Predicts the class memberships for a given instance. If an instance is
- * unclassified, the returned array elements must be all zero.
- *
- * @param inst
- * the instance to be classified
- * @return an array containing the estimated membership probabilities of the
- * test instance in each class
- */
- double[] getVotesForInstance(Instance inst);
+ /**
+ * Creates a new learner object.
+ *
+ * @return the learner
+ */
+ LocalClustererAdapter create();
- /**
- * Resets this classifier. It must be similar to starting a new classifier
- * from scratch.
- *
- */
- void resetLearning();
+ /**
+ * Predicts the class memberships for a given instance. If an instance is
+ * unclassified, the returned array elements must be all zero.
+ *
+ * @param inst
+ * the instance to be classified
+ * @return an array containing the estimated membership probabilities of the
+ * test instance in each class
+ */
+ double[] getVotesForInstance(Instance inst);
- /**
- * Trains this classifier incrementally using the given instance.
- *
- * @param inst
- * the instance to be used for training
- */
- void trainOnInstance(Instance inst);
-
- /**
- * Sets where to obtain the information of attributes of Instances
- *
- * @param dataset
- * the dataset that contains the information
- */
- public void setDataset(Instances dataset);
-
- public Instances getDataset();
+ /**
+ * Resets this classifier. It must be similar to starting a new classifier
+ * from scratch.
+ *
+ */
+ void resetLearning();
- public boolean implementsMicroClusterer();
+ /**
+ * Trains this classifier incrementally using the given instance.
+ *
+ * @param inst
+ * the instance to be used for training
+ */
+ void trainOnInstance(Instance inst);
- public Clustering getMicroClusteringResult();
-
+ /**
+ * Sets where to obtain the information of attributes of Instances
+ *
+ * @param dataset
+ * the dataset that contains the information
+ */
+ public void setDataset(Instances dataset);
+
+ public Instances getDataset();
+
+ public boolean implementsMicroClusterer();
+
+ public Clustering getMicroClusteringResult();
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java
index a397539..4688ba2 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java
@@ -33,6 +33,7 @@
import com.yahoo.labs.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+
//import weka.core.Instance;
/**
@@ -40,152 +41,160 @@
*/
final public class LocalClustererProcessor implements Processor {
- /**
+ /**
*
*/
- private static final long serialVersionUID = -1577910988699148691L;
- private static final Logger logger = LoggerFactory
- .getLogger(LocalClustererProcessor.class);
- private LocalClustererAdapter model;
- private Stream outputStream;
- private int modelId;
- private long instancesCount = 0;
- private long sampleFrequency = 1000;
+ private static final long serialVersionUID = -1577910988699148691L;
+ private static final Logger logger = LoggerFactory
+ .getLogger(LocalClustererProcessor.class);
+ private LocalClustererAdapter model;
+ private Stream outputStream;
+ private int modelId;
+ private long instancesCount = 0;
+ private long sampleFrequency = 1000;
- public long getSampleFrequency() {
- return sampleFrequency;
+ public long getSampleFrequency() {
+ return sampleFrequency;
+ }
+
+ public void setSampleFrequency(long sampleFrequency) {
+ this.sampleFrequency = sampleFrequency;
+ }
+
+ /**
+ * Sets the learner.
+ *
+ * @param model
+ * the model to set
+ */
+ public void setLearner(LocalClustererAdapter model) {
+ this.model = model;
+ }
+
+ /**
+ * Gets the learner.
+ *
+ * @return the model
+ */
+ public LocalClustererAdapter getLearner() {
+ return model;
+ }
+
+ /**
+ * Set the output streams.
+ *
+ * @param outputStream
+ * the new output stream {@link PredictionCombinerPE}.
+ */
+ public void setOutputStream(Stream outputStream) {
+
+ this.outputStream = outputStream;
+ }
+
+ /**
+ * Gets the output stream.
+ *
+ * @return the output stream
+ */
+ public Stream getOutputStream() {
+ return outputStream;
+ }
+
+ /**
+ * Gets the instances count.
+ *
+ * @return number of observation vectors used in training iteration.
+ */
+ public long getInstancesCount() {
+ return instancesCount;
+ }
+
+ /**
+ * Update stats.
+ *
+ * @param event
+ * the event
+ */
+ private void updateStats(ContentEvent event) {
+ Instance instance;
+ if (event instanceof ClusteringContentEvent) {
+ // Local Clustering
+ ClusteringContentEvent ev = (ClusteringContentEvent) event;
+ instance = ev.getInstance();
+ DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey()));
+ model.trainOnInstance(point);
+ instancesCount++;
}
- public void setSampleFrequency(long sampleFrequency) {
- this.sampleFrequency = sampleFrequency;
+ if (event instanceof ClusteringResultContentEvent) {
+ // Global Clustering
+ ClusteringResultContentEvent ev = (ClusteringResultContentEvent) event;
+ Clustering clustering = ev.getClustering();
+
+ for (int i = 0; i < clustering.size(); i++) {
+ instance = new DenseInstance(1.0, clustering.get(i).getCenter());
+ instance.setDataset(model.getDataset());
+ DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey()));
+ model.trainOnInstance(point);
+ instancesCount++;
+ }
}
- /**
- * Sets the learner.
- *
- * @param model the model to set
- */
- public void setLearner(LocalClustererAdapter model) {
- this.model = model;
+ if (instancesCount % this.sampleFrequency == 0) {
+ logger.info("Trained model using {} events with classifier id {}", instancesCount, this.modelId); // getId());
+ }
+ }
+
+ /**
+ * On event.
+ *
+ * @param event
+ * the event
+ * @return true, if successful
+ */
+ @Override
+ public boolean process(ContentEvent event) {
+
+ if (event.isLastEvent() ||
+ (instancesCount > 0 && instancesCount % this.sampleFrequency == 0)) {
+ if (model.implementsMicroClusterer()) {
+
+ Clustering clustering = model.getMicroClusteringResult();
+
+ ClusteringResultContentEvent resultEvent = new ClusteringResultContentEvent(clustering, event.isLastEvent());
+
+ this.outputStream.put(resultEvent);
+ }
}
- /**
- * Gets the learner.
- *
- * @return the model
- */
- public LocalClustererAdapter getLearner() {
- return model;
+ updateStats(event);
+ return false;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#onCreate(int)
+ */
+ @Override
+ public void onCreate(int id) {
+ this.modelId = id;
+ model = model.create();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
+ */
+ @Override
+ public Processor newProcessor(Processor sourceProcessor) {
+ LocalClustererProcessor newProcessor = new LocalClustererProcessor();
+ LocalClustererProcessor originProcessor = (LocalClustererProcessor) sourceProcessor;
+ if (originProcessor.getLearner() != null) {
+ newProcessor.setLearner(originProcessor.getLearner().create());
}
-
- /**
- * Set the output streams.
- *
- * @param outputStream the new output stream {@link PredictionCombinerPE}.
- */
- public void setOutputStream(Stream outputStream) {
-
- this.outputStream = outputStream;
- }
-
- /**
- * Gets the output stream.
- *
- * @return the output stream
- */
- public Stream getOutputStream() {
- return outputStream;
- }
-
- /**
- * Gets the instances count.
- *
- * @return number of observation vectors used in training iteration.
- */
- public long getInstancesCount() {
- return instancesCount;
- }
-
- /**
- * Update stats.
- *
- * @param event the event
- */
- private void updateStats(ContentEvent event) {
- Instance instance;
- if (event instanceof ClusteringContentEvent){
- //Local Clustering
- ClusteringContentEvent ev = (ClusteringContentEvent) event;
- instance = ev.getInstance();
- DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey()));
- model.trainOnInstance(point);
- instancesCount++;
- }
-
- if (event instanceof ClusteringResultContentEvent){
- //Global Clustering
- ClusteringResultContentEvent ev = (ClusteringResultContentEvent) event;
- Clustering clustering = ev.getClustering();
-
- for (int i=0; i<clustering.size(); i++) {
- instance = new DenseInstance(1.0,clustering.get(i).getCenter());
- instance.setDataset(model.getDataset());
- DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey()));
- model.trainOnInstance(point);
- instancesCount++;
- }
- }
-
- if (instancesCount % this.sampleFrequency == 0) {
- logger.info("Trained model using {} events with classifier id {}", instancesCount, this.modelId); // getId());
- }
- }
-
- /**
- * On event.
- *
- * @param event the event
- * @return true, if successful
- */
- @Override
- public boolean process(ContentEvent event) {
-
- if (event.isLastEvent() ||
- (instancesCount > 0 && instancesCount% this.sampleFrequency == 0)) {
- if (model.implementsMicroClusterer()) {
-
- Clustering clustering = model.getMicroClusteringResult();
-
- ClusteringResultContentEvent resultEvent = new ClusteringResultContentEvent(clustering, event.isLastEvent());
-
- this.outputStream.put(resultEvent);
- }
- }
-
- updateStats(event);
- return false;
- }
-
- /* (non-Javadoc)
- * @see samoa.core.Processor#onCreate(int)
- */
- @Override
- public void onCreate(int id) {
- this.modelId = id;
- model = model.create();
- }
-
- /* (non-Javadoc)
- * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
- */
- @Override
- public Processor newProcessor(Processor sourceProcessor) {
- LocalClustererProcessor newProcessor = new LocalClustererProcessor();
- LocalClustererProcessor originProcessor = (LocalClustererProcessor) sourceProcessor;
- if (originProcessor.getLearner() != null) {
- newProcessor.setLearner(originProcessor.getLearner().create());
- }
- newProcessor.setOutputStream(originProcessor.getOutputStream());
- return newProcessor;
- }
+ newProcessor.setOutputStream(originProcessor.getOutputStream());
+ return newProcessor;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java
index 894a0cc..ae8684b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java
@@ -42,56 +42,59 @@
*/
public final class SingleLearner implements Learner, Configurable {
- private static final long serialVersionUID = 684111382631697031L;
-
- private LocalClustererProcessor learnerP;
-
- private Stream resultStream;
-
- private Instances dataset;
+ private static final long serialVersionUID = 684111382631697031L;
- public ClassOption learnerOption = new ClassOption("learner", 'l',
- "Learner to train.", LocalClustererAdapter.class, ClustreamClustererAdapter.class.getName());
-
- private TopologyBuilder builder;
+ private LocalClustererProcessor learnerP;
- private int parallelism;
+ private Stream resultStream;
- @Override
- public void init(TopologyBuilder builder, Instances dataset, int parallelism){
- this.builder = builder;
- this.dataset = dataset;
- this.parallelism = parallelism;
- this.setLayout();
- }
+ private Instances dataset;
+ public ClassOption learnerOption = new ClassOption("learner", 'l',
+ "Learner to train.", LocalClustererAdapter.class, ClustreamClustererAdapter.class.getName());
- protected void setLayout() {
- learnerP = new LocalClustererProcessor();
- LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue();
- learner.setDataset(this.dataset);
- learnerP.setLearner(learner);
-
- this.builder.addProcessor(learnerP, this.parallelism);
- resultStream = this.builder.createStream(learnerP);
-
- learnerP.setOutputStream(resultStream);
- }
+ private TopologyBuilder builder;
- /* (non-Javadoc)
- * @see samoa.classifiers.Classifier#getInputProcessingItem()
- */
- @Override
- public Processor getInputProcessor() {
- return learnerP;
- }
-
- /* (non-Javadoc)
- * @see samoa.learners.Learner#getResultStreams()
- */
- @Override
- public Set<Stream> getResultStreams() {
- Set<Stream> streams = ImmutableSet.of(this.resultStream);
- return streams;
- }
+ private int parallelism;
+
+ @Override
+ public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+ this.builder = builder;
+ this.dataset = dataset;
+ this.parallelism = parallelism;
+ this.setLayout();
+ }
+
+ protected void setLayout() {
+ learnerP = new LocalClustererProcessor();
+ LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue();
+ learner.setDataset(this.dataset);
+ learnerP.setLearner(learner);
+
+ this.builder.addProcessor(learnerP, this.parallelism);
+ resultStream = this.builder.createStream(learnerP);
+
+ learnerP.setOutputStream(resultStream);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.classifiers.Classifier#getInputProcessingItem()
+ */
+ @Override
+ public Processor getInputProcessor() {
+ return learnerP;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.learners.Learner#getResultStreams()
+ */
+ @Override
+ public Set<Stream> getResultStreams() {
+ Set<Stream> streams = ImmutableSet.of(this.resultStream);
+ return streams;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java
index e75a1bd..7d266f5 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java
@@ -34,65 +34,67 @@
*/
public class ClusteringDistributorProcessor implements Processor {
- private static final long serialVersionUID = -1550901409625192730L;
+ private static final long serialVersionUID = -1550901409625192730L;
- private Stream outputStream;
- private Stream evaluationStream;
- private int numInstances;
+ private Stream outputStream;
+ private Stream evaluationStream;
+ private int numInstances;
- public Stream getOutputStream() {
- return outputStream;
+ public Stream getOutputStream() {
+ return outputStream;
+ }
+
+ public void setOutputStream(Stream outputStream) {
+ this.outputStream = outputStream;
+ }
+
+ public Stream getEvaluationStream() {
+ return evaluationStream;
+ }
+
+ public void setEvaluationStream(Stream evaluationStream) {
+ this.evaluationStream = evaluationStream;
+ }
+
+ /**
+ * Process event.
+ *
+ * @param event
+ * the event
+ * @return true, if successful
+ */
+ public boolean process(ContentEvent event) {
+ // distinguish between ClusteringContentEvent and
+ // ClusteringEvaluationContentEvent
+ if (event instanceof ClusteringContentEvent) {
+ ClusteringContentEvent cce = (ClusteringContentEvent) event;
+ outputStream.put(event);
+ if (cce.isSample()) {
+ evaluationStream.put(new ClusteringEvaluationContentEvent(null,
+ new DataPoint(cce.getInstance(), numInstances++), cce.isLastEvent()));
+ }
+ } else if (event instanceof ClusteringEvaluationContentEvent) {
+ evaluationStream.put(event);
}
+ return true;
+ }
- public void setOutputStream(Stream outputStream) {
- this.outputStream = outputStream;
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
+ */
+ @Override
+ public Processor newProcessor(Processor sourceProcessor) {
+ ClusteringDistributorProcessor newProcessor = new ClusteringDistributorProcessor();
+ ClusteringDistributorProcessor originProcessor = (ClusteringDistributorProcessor) sourceProcessor;
+ if (originProcessor.getOutputStream() != null)
+ newProcessor.setOutputStream(originProcessor.getOutputStream());
+ if (originProcessor.getEvaluationStream() != null)
+ newProcessor.setEvaluationStream(originProcessor.getEvaluationStream());
+ return newProcessor;
+ }
- public Stream getEvaluationStream() {
- return evaluationStream;
- }
-
- public void setEvaluationStream(Stream evaluationStream) {
- this.evaluationStream = evaluationStream;
- }
-
- /**
- * Process event.
- *
- * @param event
- * the event
- * @return true, if successful
- */
- public boolean process(ContentEvent event) {
- // distinguish between ClusteringContentEvent and ClusteringEvaluationContentEvent
- if (event instanceof ClusteringContentEvent) {
- ClusteringContentEvent cce = (ClusteringContentEvent) event;
- outputStream.put(event);
- if (cce.isSample()) {
- evaluationStream.put(new ClusteringEvaluationContentEvent(null, new DataPoint(cce.getInstance(), numInstances++), cce.isLastEvent()));
- }
- } else if (event instanceof ClusteringEvaluationContentEvent) {
- evaluationStream.put(event);
- }
- return true;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
- */
- @Override
- public Processor newProcessor(Processor sourceProcessor) {
- ClusteringDistributorProcessor newProcessor = new ClusteringDistributorProcessor();
- ClusteringDistributorProcessor originProcessor = (ClusteringDistributorProcessor) sourceProcessor;
- if (originProcessor.getOutputStream() != null)
- newProcessor.setOutputStream(originProcessor.getOutputStream());
- if (originProcessor.getEvaluationStream() != null)
- newProcessor.setEvaluationStream(originProcessor.getEvaluationStream());
- return newProcessor;
- }
-
- public void onCreate(int id) {
- }
+ public void onCreate(int id) {
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java
index d924733..edfecfa 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java
@@ -45,74 +45,75 @@
*/
public final class DistributedClusterer implements Learner, Configurable {
- private static final long serialVersionUID = 684111382631697031L;
+ private static final long serialVersionUID = 684111382631697031L;
- private Stream resultStream;
+ private Stream resultStream;
- private Instances dataset;
+ private Instances dataset;
- public ClassOption learnerOption = new ClassOption("learner", 'l', "Clusterer to use.", LocalClustererAdapter.class,
- ClustreamClustererAdapter.class.getName());
+ public ClassOption learnerOption = new ClassOption("learner", 'l', "Clusterer to use.", LocalClustererAdapter.class,
+ ClustreamClustererAdapter.class.getName());
- public IntOption paralellismOption = new IntOption("paralellismOption", 'P', "The paralellism level for concurrent processes", 2, 1, Integer.MAX_VALUE);
+ public IntOption paralellismOption = new IntOption("paralellismOption", 'P',
+ "The paralellism level for concurrent processes", 2, 1, Integer.MAX_VALUE);
- private TopologyBuilder builder;
+ private TopologyBuilder builder;
-// private ClusteringDistributorProcessor distributorP;
- private LocalClustererProcessor learnerP;
+ // private ClusteringDistributorProcessor distributorP;
+ private LocalClustererProcessor learnerP;
-// private Stream distributorToLocalStream;
- private Stream localToGlobalStream;
+ // private Stream distributorToLocalStream;
+ private Stream localToGlobalStream;
-// private int parallelism;
+ // private int parallelism;
- @Override
- public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
- this.builder = builder;
- this.dataset = dataset;
-// this.parallelism = parallelism;
- this.setLayout();
- }
+ @Override
+ public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+ this.builder = builder;
+ this.dataset = dataset;
+ // this.parallelism = parallelism;
+ this.setLayout();
+ }
- protected void setLayout() {
- // Distributor
-// distributorP = new ClusteringDistributorProcessor();
-// this.builder.addProcessor(distributorP, parallelism);
-// distributorToLocalStream = this.builder.createStream(distributorP);
-// distributorP.setOutputStream(distributorToLocalStream);
-// distributorToGlobalStream = this.builder.createStream(distributorP);
+ protected void setLayout() {
+ // Distributor
+ // distributorP = new ClusteringDistributorProcessor();
+ // this.builder.addProcessor(distributorP, parallelism);
+ // distributorToLocalStream = this.builder.createStream(distributorP);
+ // distributorP.setOutputStream(distributorToLocalStream);
+ // distributorToGlobalStream = this.builder.createStream(distributorP);
- // Local Clustering
- learnerP = new LocalClustererProcessor();
- LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue();
- learner.setDataset(this.dataset);
- learnerP.setLearner(learner);
- builder.addProcessor(learnerP, this.paralellismOption.getValue());
- localToGlobalStream = this.builder.createStream(learnerP);
- learnerP.setOutputStream(localToGlobalStream);
+ // Local Clustering
+ learnerP = new LocalClustererProcessor();
+ LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue();
+ learner.setDataset(this.dataset);
+ learnerP.setLearner(learner);
+ builder.addProcessor(learnerP, this.paralellismOption.getValue());
+ localToGlobalStream = this.builder.createStream(learnerP);
+ learnerP.setOutputStream(localToGlobalStream);
- // Global Clustering
- LocalClustererProcessor globalClusteringCombinerP = new LocalClustererProcessor();
- LocalClustererAdapter globalLearner = (LocalClustererAdapter) this.learnerOption.getValue();
- globalLearner.setDataset(this.dataset);
- globalClusteringCombinerP.setLearner(learner);
- builder.addProcessor(globalClusteringCombinerP, 1);
- builder.connectInputAllStream(localToGlobalStream, globalClusteringCombinerP);
+ // Global Clustering
+ LocalClustererProcessor globalClusteringCombinerP = new LocalClustererProcessor();
+ LocalClustererAdapter globalLearner = (LocalClustererAdapter) this.learnerOption.getValue();
+ globalLearner.setDataset(this.dataset);
+ globalClusteringCombinerP.setLearner(learner);
+ builder.addProcessor(globalClusteringCombinerP, 1);
+ builder.connectInputAllStream(localToGlobalStream, globalClusteringCombinerP);
- // Output Stream
- resultStream = this.builder.createStream(globalClusteringCombinerP);
- globalClusteringCombinerP.setOutputStream(resultStream);
- }
+ // Output Stream
+ resultStream = this.builder.createStream(globalClusteringCombinerP);
+ globalClusteringCombinerP.setOutputStream(resultStream);
+ }
- @Override
- public Processor getInputProcessor() {
-// return distributorP;
- return learnerP;
- }
+ @Override
+ public Processor getInputProcessor() {
+ // return distributorP;
+ return learnerP;
+ }
- @Override
- public Set<Stream> getResultStreams() {
- Set<Stream> streams = ImmutableSet.of(this.resultStream);
- return streams;
- }
+ @Override
+ public Set<Stream> getResultStreams() {
+ Set<Stream> streams = ImmutableSet.of(this.resultStream);
+ return streams;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java
index 37303ec..45fa228 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java
@@ -21,60 +21,63 @@
*/
import com.yahoo.labs.samoa.moa.core.SerializeUtils;
+
//import moa.core.SizeOf;
/**
- * Abstract MOA Object. All classes that are serializable, copiable,
- * can measure its size, and can give a description, extend this class.
- *
+ * Abstract MOA Object. All classes that are serializable, copyable, can measure
+ * its size, and can give a description, extend this class.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public abstract class AbstractMOAObject implements MOAObject {
- @Override
- public MOAObject copy() {
- return copy(this);
- }
+ @Override
+ public MOAObject copy() {
+ return copy(this);
+ }
- @Override
- public int measureByteSize() {
- return measureByteSize(this);
- }
+ @Override
+ public int measureByteSize() {
+ return measureByteSize(this);
+ }
- /**
- * Returns a description of the object.
- *
- * @return a description of the object
- */
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder();
- getDescription(sb, 0);
- return sb.toString();
- }
+ /**
+ * Returns a description of the object.
+ *
+ * @return a description of the object
+ */
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ getDescription(sb, 0);
+ return sb.toString();
+ }
- /**
- * This method produces a copy of an object.
- *
- * @param obj object to copy
- * @return a copy of the object
- */
- public static MOAObject copy(MOAObject obj) {
- try {
- return (MOAObject) SerializeUtils.copyObject(obj);
- } catch (Exception e) {
- throw new RuntimeException("Object copy failed.", e);
- }
+ /**
+ * This method produces a copy of an object.
+ *
+ * @param obj
+ * object to copy
+ * @return a copy of the object
+ */
+ public static MOAObject copy(MOAObject obj) {
+ try {
+ return (MOAObject) SerializeUtils.copyObject(obj);
+ } catch (Exception e) {
+ throw new RuntimeException("Object copy failed.", e);
}
+ }
- /**
- * Gets the memory size of an object.
- *
- * @param obj object to measure the memory size
- * @return the memory size of this object
- */
- public static int measureByteSize(MOAObject obj) {
- return 0; //(int) SizeOf.fullSizeOf(obj);
- }
+ /**
+ * Gets the memory size of an object.
+ *
+ * @param obj
+ * object to measure the memory size
+ * @return the memory size of this object
+ */
+ public static int measureByteSize(MOAObject obj) {
+ return 0; // (int) SizeOf.fullSizeOf(obj);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java
index cc26eaa..cd98892 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java
@@ -23,36 +23,38 @@
import java.io.Serializable;
/**
- * Interface implemented by classes in MOA, so that all are serializable,
- * can produce copies of their objects, and can measure its memory size.
- * They also give a string description.
- *
+ * Interface implemented by classes in MOA, so that all are serializable, can
+ * produce copies of their objects, and can measure its memory size. They also
+ * give a string description.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface MOAObject extends Serializable {
- /**
- * Gets the memory size of this object.
- *
- * @return the memory size of this object
- */
- public int measureByteSize();
+ /**
+ * Gets the memory size of this object.
+ *
+ * @return the memory size of this object
+ */
+ public int measureByteSize();
- /**
- * This method produces a copy of this object.
- *
- * @return a copy of this object
- */
- public MOAObject copy();
+ /**
+ * This method produces a copy of this object.
+ *
+ * @return a copy of this object
+ */
+ public MOAObject copy();
- /**
- * Returns a string representation of this object.
- * Used in <code>AbstractMOAObject.toString</code>
- * to give a string representation of the object.
- *
- * @param sb the stringbuilder to add the description
- * @param indent the number of characters to indent
- */
- public void getDescription(StringBuilder sb, int indent);
+ /**
+ * Returns a string representation of this object. Used in
+ * <code>AbstractMOAObject.toString</code> to give a string representation of
+ * the object.
+ *
+ * @param sb
+ * the stringbuilder to add the description
+ * @param indent
+ * the number of characters to indent
+ */
+ public void getDescription(StringBuilder sb, int indent);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java
index 09de49e..21bbf4b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.classifiers;
/*
@@ -41,338 +40,347 @@
public abstract class AbstractClassifier extends AbstractOptionHandler implements Classifier {
- @Override
- public String getPurposeString() {
- return "MOA Classifier: " + getClass().getCanonicalName();
+ @Override
+ public String getPurposeString() {
+ return "MOA Classifier: " + getClass().getCanonicalName();
+ }
+
+ /** Header of the instances of the data stream */
+ protected InstancesHeader modelContext;
+
+ /** Sum of the weights of the instances trained by this model */
+ protected double trainingWeightSeenByModel = 0.0;
+
+ /** Random seed used in randomizable learners */
+ protected int randomSeed = 1;
+
+ /** Option for randomizable learners to change the random seed */
+ protected IntOption randomSeedOption;
+
+ /** Random Generator used in randomizable learners */
+ public Random classifierRandom;
+
+ /**
+   * Creates a classifier and sets up the random seed option if the classifier
+   * is randomizable.
+ */
+ public AbstractClassifier() {
+ if (isRandomizable()) {
+ this.randomSeedOption = new IntOption("randomSeed", 'r',
+ "Seed for random behaviour of the classifier.", 1);
}
+ }
- /** Header of the instances of the data stream */
- protected InstancesHeader modelContext;
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ if (this.randomSeedOption != null) {
+ this.randomSeed = this.randomSeedOption.getValue();
+ }
+ if (!trainingHasStarted()) {
+ resetLearning();
+ }
+ }
- /** Sum of the weights of the instances trained by this model */
- protected double trainingWeightSeenByModel = 0.0;
+ @Override
+ public double[] getVotesForInstance(Example<Instance> example) {
+ return getVotesForInstance(example.getData());
+ }
- /** Random seed used in randomizable learners */
- protected int randomSeed = 1;
+ @Override
+ public abstract double[] getVotesForInstance(Instance inst);
- /** Option for randomizable learners to change the random seed */
- protected IntOption randomSeedOption;
+ @Override
+ public void setModelContext(InstancesHeader ih) {
+ if ((ih != null) && (ih.classIndex() < 0)) {
+ throw new IllegalArgumentException(
+ "Context for a classifier must include a class to learn");
+ }
+ if (trainingHasStarted()
+ && (this.modelContext != null)
+ && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) {
+ throw new IllegalArgumentException(
+ "New context is not compatible with existing model");
+ }
+ this.modelContext = ih;
+ }
- /** Random Generator used in randomizable learners */
- public Random classifierRandom;
+ @Override
+ public InstancesHeader getModelContext() {
+ return this.modelContext;
+ }
- /**
- * Creates an classifier and setups the random seed option
- * if the classifier is randomizable.
- */
- public AbstractClassifier() {
- if (isRandomizable()) {
- this.randomSeedOption = new IntOption("randomSeed", 'r',
- "Seed for random behaviour of the classifier.", 1);
+ @Override
+ public void setRandomSeed(int s) {
+ this.randomSeed = s;
+ if (this.randomSeedOption != null) {
+ // keep option consistent
+ this.randomSeedOption.setValue(s);
+ }
+ }
+
+ @Override
+ public boolean trainingHasStarted() {
+ return this.trainingWeightSeenByModel > 0.0;
+ }
+
+ @Override
+ public double trainingWeightSeenByModel() {
+ return this.trainingWeightSeenByModel;
+ }
+
+ @Override
+ public void resetLearning() {
+ this.trainingWeightSeenByModel = 0.0;
+ if (isRandomizable()) {
+ this.classifierRandom = new Random(this.randomSeed);
+ }
+ resetLearningImpl();
+ }
+
+ @Override
+ public void trainOnInstance(Instance inst) {
+ if (inst.weight() > 0.0) {
+ this.trainingWeightSeenByModel += inst.weight();
+ trainOnInstanceImpl(inst);
+ }
+ }
+
+ @Override
+ public Measurement[] getModelMeasurements() {
+ List<Measurement> measurementList = new LinkedList<>();
+ measurementList.add(new Measurement("model training instances",
+ trainingWeightSeenByModel()));
+ measurementList.add(new Measurement("model serialized size (bytes)",
+ measureByteSize()));
+ Measurement[] modelMeasurements = getModelMeasurementsImpl();
+ if (modelMeasurements != null) {
+ measurementList.addAll(Arrays.asList(modelMeasurements));
+ }
+ // add average of sub-model measurements
+ Learner[] subModels = getSublearners();
+ if ((subModels != null) && (subModels.length > 0)) {
+ List<Measurement[]> subMeasurements = new LinkedList<>();
+ for (Learner subModel : subModels) {
+ if (subModel != null) {
+ subMeasurements.add(subModel.getModelMeasurements());
}
+ }
+ Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements
+ .toArray(new Measurement[subMeasurements.size()][]));
+ measurementList.addAll(Arrays.asList(avgMeasurements));
}
+ return measurementList.toArray(new Measurement[measurementList.size()]);
+ }
- @Override
- public void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- if (this.randomSeedOption != null) {
- this.randomSeed = this.randomSeedOption.getValue();
+ @Override
+ public void getDescription(StringBuilder out, int indent) {
+ StringUtils.appendIndented(out, indent, "Model type: ");
+ out.append(this.getClass().getName());
+ StringUtils.appendNewline(out);
+ Measurement.getMeasurementsDescription(getModelMeasurements(), out,
+ indent);
+ StringUtils.appendNewlineIndented(out, indent, "Model description:");
+ StringUtils.appendNewline(out);
+ if (trainingHasStarted()) {
+ getModelDescription(out, indent);
+ } else {
+ StringUtils.appendIndented(out, indent,
+ "Model has not been trained.");
+ }
+ }
+
+ @Override
+ public Learner[] getSublearners() {
+ return null;
+ }
+
+ @Override
+ public Classifier[] getSubClassifiers() {
+ return null;
+ }
+
+ @Override
+ public Classifier copy() {
+ return (Classifier) super.copy();
+ }
+
+ @Override
+ public MOAObject getModel() {
+ return this;
+ }
+
+ @Override
+ public void trainOnInstance(Example<Instance> example) {
+ trainOnInstance(example.getData());
+ }
+
+ @Override
+ public boolean correctlyClassifies(Instance inst) {
+ return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue();
+ }
+
+ /**
+ * Gets the name of the attribute of the class from the header.
+ *
+ * @return the string with name of the attribute of the class
+ */
+ public String getClassNameString() {
+ return InstancesHeader.getClassNameString(this.modelContext);
+ }
+
+ /**
+ * Gets the name of a label of the class from the header.
+ *
+ * @param classLabelIndex
+ * the label index
+ * @return the name of the label of the class
+ */
+ public String getClassLabelString(int classLabelIndex) {
+ return InstancesHeader.getClassLabelString(this.modelContext,
+ classLabelIndex);
+ }
+
+ /**
+ * Gets the name of an attribute from the header.
+ *
+ * @param attIndex
+ * the attribute index
+ * @return the name of the attribute
+ */
+ public String getAttributeNameString(int attIndex) {
+ return InstancesHeader.getAttributeNameString(this.modelContext, attIndex);
+ }
+
+ /**
+ * Gets the name of a value of an attribute from the header.
+ *
+ * @param attIndex
+ * the attribute index
+ * @param valIndex
+ * the value of the attribute
+ * @return the name of the value of the attribute
+ */
+ public String getNominalValueString(int attIndex, int valIndex) {
+ return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex);
+ }
+
+ /**
+ * Returns if two contexts or headers of instances are compatible.<br>
+ * <br>
+ *
+ * Two contexts are compatible if they follow the following rules:<br>
+ * Rule 1: num classes can increase but never decrease<br>
+ * Rule 2: num attributes can increase but never decrease<br>
+ * Rule 3: num nominal attribute values can increase but never decrease<br>
+ * Rule 4: attribute types must stay in the same order (although class can
+ * move; is always skipped over)<br>
+ * <br>
+ *
+ * Attribute names are free to change, but should always still represent the
+ * original attributes.
+ *
+ * @param originalContext
+ * the first context to compare
+ * @param newContext
+ * the second context to compare
+ * @return true if the two contexts are compatible.
+ */
+ public static boolean contextIsCompatible(InstancesHeader originalContext,
+ InstancesHeader newContext) {
+
+ if (newContext.numClasses() < originalContext.numClasses()) {
+ return false; // rule 1
+ }
+ if (newContext.numAttributes() < originalContext.numAttributes()) {
+ return false; // rule 2
+ }
+ int oPos = 0;
+ int nPos = 0;
+ while (oPos < originalContext.numAttributes()) {
+ if (oPos == originalContext.classIndex()) {
+ oPos++;
+ if (!(oPos < originalContext.numAttributes())) {
+ break;
}
- if (!trainingHasStarted()) {
- resetLearning();
+ }
+ if (nPos == newContext.classIndex()) {
+ nPos++;
+ }
+ if (originalContext.attribute(oPos).isNominal()) {
+ if (!newContext.attribute(nPos).isNominal()) {
+ return false; // rule 4
}
- }
-
-
- @Override
- public double[] getVotesForInstance(Example<Instance> example){
- return getVotesForInstance(example.getData());
- }
-
- @Override
- public abstract double[] getVotesForInstance(Instance inst);
-
- @Override
- public void setModelContext(InstancesHeader ih) {
- if ((ih != null) && (ih.classIndex() < 0)) {
- throw new IllegalArgumentException(
- "Context for a classifier must include a class to learn");
+ if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) {
+ return false; // rule 3
}
- if (trainingHasStarted()
- && (this.modelContext != null)
- && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) {
- throw new IllegalArgumentException(
- "New context is not compatible with existing model");
+ } else {
+ assert (originalContext.attribute(oPos).isNumeric());
+ if (!newContext.attribute(nPos).isNumeric()) {
+ return false; // rule 4
}
- this.modelContext = ih;
+ }
+ oPos++;
+ nPos++;
}
+ return true; // all checks clear
+ }
- @Override
- public InstancesHeader getModelContext() {
- return this.modelContext;
- }
+ /**
+ * Resets this classifier. It must be similar to starting a new classifier
+ * from scratch. <br>
+ * <br>
+ *
+ * The reason for ...Impl methods: ease programmer burden by not requiring
+ * them to remember calls to super in overridden methods. Note that this will
+ * produce compiler errors if not overridden.
+ */
+ public abstract void resetLearningImpl();
- @Override
- public void setRandomSeed(int s) {
- this.randomSeed = s;
- if (this.randomSeedOption != null) {
- // keep option consistent
- this.randomSeedOption.setValue(s);
- }
- }
+ /**
+ * Trains this classifier incrementally using the given instance.<br>
+ * <br>
+ *
+ * The reason for ...Impl methods: ease programmer burden by not requiring
+ * them to remember calls to super in overridden methods. Note that this will
+ * produce compiler errors if not overridden.
+ *
+ * @param inst
+ * the instance to be used for training
+ */
+ public abstract void trainOnInstanceImpl(Instance inst);
- @Override
- public boolean trainingHasStarted() {
- return this.trainingWeightSeenByModel > 0.0;
- }
+ /**
+ * Gets the current measurements of this classifier.<br>
+ * <br>
+ *
+ * The reason for ...Impl methods: ease programmer burden by not requiring
+ * them to remember calls to super in overridden methods. Note that this will
+ * produce compiler errors if not overridden.
+ *
+ * @return an array of measurements to be used in evaluation tasks
+ */
+ protected abstract Measurement[] getModelMeasurementsImpl();
- @Override
- public double trainingWeightSeenByModel() {
- return this.trainingWeightSeenByModel;
- }
+ /**
+ * Returns a string representation of the model.
+ *
+ * @param out
+ * the stringbuilder to add the description
+ * @param indent
+ * the number of characters to indent
+ */
+ public abstract void getModelDescription(StringBuilder out, int indent);
- @Override
- public void resetLearning() {
- this.trainingWeightSeenByModel = 0.0;
- if (isRandomizable()) {
- this.classifierRandom = new Random(this.randomSeed);
- }
- resetLearningImpl();
- }
-
- @Override
- public void trainOnInstance(Instance inst) {
- if (inst.weight() > 0.0) {
- this.trainingWeightSeenByModel += inst.weight();
- trainOnInstanceImpl(inst);
- }
- }
-
- @Override
- public Measurement[] getModelMeasurements() {
- List<Measurement> measurementList = new LinkedList<>();
- measurementList.add(new Measurement("model training instances",
- trainingWeightSeenByModel()));
- measurementList.add(new Measurement("model serialized size (bytes)",
- measureByteSize()));
- Measurement[] modelMeasurements = getModelMeasurementsImpl();
- if (modelMeasurements != null) {
- measurementList.addAll(Arrays.asList(modelMeasurements));
- }
- // add average of sub-model measurements
- Learner[] subModels = getSublearners();
- if ((subModels != null) && (subModels.length > 0)) {
- List<Measurement[]> subMeasurements = new LinkedList<>();
- for (Learner subModel : subModels) {
- if (subModel != null) {
- subMeasurements.add(subModel.getModelMeasurements());
- }
- }
- Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][]));
- measurementList.addAll(Arrays.asList(avgMeasurements));
- }
- return measurementList.toArray(new Measurement[measurementList.size()]);
- }
-
- @Override
- public void getDescription(StringBuilder out, int indent) {
- StringUtils.appendIndented(out, indent, "Model type: ");
- out.append(this.getClass().getName());
- StringUtils.appendNewline(out);
- Measurement.getMeasurementsDescription(getModelMeasurements(), out,
- indent);
- StringUtils.appendNewlineIndented(out, indent, "Model description:");
- StringUtils.appendNewline(out);
- if (trainingHasStarted()) {
- getModelDescription(out, indent);
- } else {
- StringUtils.appendIndented(out, indent,
- "Model has not been trained.");
- }
- }
-
- @Override
- public Learner[] getSublearners() {
- return null;
- }
-
-
- @Override
- public Classifier[] getSubClassifiers() {
- return null;
- }
-
-
- @Override
- public Classifier copy() {
- return (Classifier) super.copy();
- }
-
-
- @Override
- public MOAObject getModel(){
- return this;
- }
-
- @Override
- public void trainOnInstance(Example<Instance> example){
- trainOnInstance(example.getData());
- }
-
- @Override
- public boolean correctlyClassifies(Instance inst) {
- return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue();
- }
-
- /**
- * Gets the name of the attribute of the class from the header.
- *
- * @return the string with name of the attribute of the class
- */
- public String getClassNameString() {
- return InstancesHeader.getClassNameString(this.modelContext);
- }
-
- /**
- * Gets the name of a label of the class from the header.
- *
- * @param classLabelIndex the label index
- * @return the name of the label of the class
- */
- public String getClassLabelString(int classLabelIndex) {
- return InstancesHeader.getClassLabelString(this.modelContext,
- classLabelIndex);
- }
-
- /**
- * Gets the name of an attribute from the header.
- *
- * @param attIndex the attribute index
- * @return the name of the attribute
- */
- public String getAttributeNameString(int attIndex) {
- return InstancesHeader.getAttributeNameString(this.modelContext, attIndex);
- }
-
- /**
- * Gets the name of a value of an attribute from the header.
- *
- * @param attIndex the attribute index
- * @param valIndex the value of the attribute
- * @return the name of the value of the attribute
- */
- public String getNominalValueString(int attIndex, int valIndex) {
- return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex);
- }
-
-
- /**
- * Returns if two contexts or headers of instances are compatible.<br><br>
- *
- * Two contexts are compatible if they follow the following rules:<br>
- * Rule 1: num classes can increase but never decrease<br>
- * Rule 2: num attributes can increase but never decrease<br>
- * Rule 3: num nominal attribute values can increase but never decrease<br>
- * Rule 4: attribute types must stay in the same order (although class
- * can move; is always skipped over)<br><br>
- *
- * Attribute names are free to change, but should always still represent
- * the original attributes.
- *
- * @param originalContext the first context to compare
- * @param newContext the second context to compare
- * @return true if the two contexts are compatible.
- */
- public static boolean contextIsCompatible(InstancesHeader originalContext,
- InstancesHeader newContext) {
-
- if (newContext.numClasses() < originalContext.numClasses()) {
- return false; // rule 1
- }
- if (newContext.numAttributes() < originalContext.numAttributes()) {
- return false; // rule 2
- }
- int oPos = 0;
- int nPos = 0;
- while (oPos < originalContext.numAttributes()) {
- if (oPos == originalContext.classIndex()) {
- oPos++;
- if (!(oPos < originalContext.numAttributes())) {
- break;
- }
- }
- if (nPos == newContext.classIndex()) {
- nPos++;
- }
- if (originalContext.attribute(oPos).isNominal()) {
- if (!newContext.attribute(nPos).isNominal()) {
- return false; // rule 4
- }
- if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) {
- return false; // rule 3
- }
- } else {
- assert (originalContext.attribute(oPos).isNumeric());
- if (!newContext.attribute(nPos).isNumeric()) {
- return false; // rule 4
- }
- }
- oPos++;
- nPos++;
- }
- return true; // all checks clear
- }
-
-
-
- /**
- * Resets this classifier. It must be similar to
- * starting a new classifier from scratch. <br><br>
- *
- * The reason for ...Impl methods: ease programmer burden by not requiring
- * them to remember calls to super in overridden methods.
- * Note that this will produce compiler errors if not overridden.
- */
- public abstract void resetLearningImpl();
-
- /**
- * Trains this classifier incrementally using the given instance.<br><br>
- *
- * The reason for ...Impl methods: ease programmer burden by not requiring
- * them to remember calls to super in overridden methods.
- * Note that this will produce compiler errors if not overridden.
- *
- * @param inst the instance to be used for training
- */
- public abstract void trainOnInstanceImpl(Instance inst);
-
- /**
- * Gets the current measurements of this classifier.<br><br>
- *
- * The reason for ...Impl methods: ease programmer burden by not requiring
- * them to remember calls to super in overridden methods.
- * Note that this will produce compiler errors if not overridden.
- *
- * @return an array of measurements to be used in evaluation tasks
- */
- protected abstract Measurement[] getModelMeasurementsImpl();
-
- /**
- * Returns a string representation of the model.
- *
- * @param out the stringbuilder to add the description
- * @param indent the number of characters to indent
- */
- public abstract void getModelDescription(StringBuilder out, int indent);
-
- /**
- * Gets the index of the attribute in the instance,
- * given the index of the attribute in the learner.
- *
- * @param index the index of the attribute in the learner
- * @return the index in the instance
- */
- protected static int modelAttIndexToInstanceAttIndex(int index) {
- return index; //inst.classIndex() > index ? index : index + 1;
- }
+ /**
+ * Gets the index of the attribute in the instance, given the index of the
+ * attribute in the learner.
+ *
+ * @param index
+ * the index of the attribute in the learner
+ * @return the index in the instance
+ */
+ protected static int modelAttIndexToInstanceAttIndex(int index) {
+ return index; // inst.classIndex() > index ? index : index + 1;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java
index efbc918..bdda15a 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java
@@ -26,52 +26,55 @@
/**
* Classifier interface for incremental classification models.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface Classifier extends Learner<Example<Instance>> {
- /**
- * Gets the classifiers of this ensemble. Returns null if this learner is a
- * single learner.
- *
- * @return an array of the learners of the ensemble
- */
- public Classifier[] getSubClassifiers();
+ /**
+ * Gets the classifiers of this ensemble. Returns null if this learner is a
+ * single learner.
+ *
+ * @return an array of the learners of the ensemble
+ */
+ public Classifier[] getSubClassifiers();
- /**
- * Produces a copy of this learner.
- *
- * @return the copy of this learner
- */
- public Classifier copy();
+ /**
+ * Produces a copy of this learner.
+ *
+ * @return the copy of this learner
+ */
+ public Classifier copy();
- /**
- * Gets whether this classifier correctly classifies an instance. Uses
- * getVotesForInstance to obtain the prediction and the instance to obtain
- * its true class.
- *
- *
- * @param inst the instance to be classified
- * @return true if the instance is correctly classified
- */
- public boolean correctlyClassifies(Instance inst);
+ /**
+ * Gets whether this classifier correctly classifies an instance. Uses
+ * getVotesForInstance to obtain the prediction and the instance to obtain its
+ * true class.
+ *
+ *
+ * @param inst
+ * the instance to be classified
+ * @return true if the instance is correctly classified
+ */
+ public boolean correctlyClassifies(Instance inst);
- /**
- * Trains this learner incrementally using the given example.
- *
- * @param inst the instance to be used for training
- */
- public void trainOnInstance(Instance inst);
+ /**
+ * Trains this learner incrementally using the given example.
+ *
+ * @param inst
+ * the instance to be used for training
+ */
+ public void trainOnInstance(Instance inst);
- /**
- * Predicts the class memberships for a given instance. If an instance is
- * unclassified, the returned array elements must be all zero.
- *
- * @param inst the instance to be classified
- * @return an array containing the estimated membership probabilities of the
- * test instance in each class
- */
- public double[] getVotesForInstance(Instance inst);
+ /**
+ * Predicts the class memberships for a given instance. If an instance is
+ * unclassified, the returned array elements must be all zero.
+ *
+ * @param inst
+ * the instance to be classified
+ * @return an array containing the estimated membership probabilities of the
+ * test instance in each class
+ */
+ public double[] getVotesForInstance(Instance inst);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java
index 758f5c4..53d86a2 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java
@@ -21,11 +21,12 @@
*/
/**
- * Regressor interface for incremental regression models. It is used only in the GUI Regression Tab.
- *
+ * Regressor interface for incremental regression models. It is used only in the
+ * GUI Regression Tab.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface Regressor {
-
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java
index 1ecc9ed..e469c72 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java
@@ -25,44 +25,45 @@
/**
* Class for computing attribute split suggestions given a split test.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class AttributeSplitSuggestion extends AbstractMOAObject implements Comparable<AttributeSplitSuggestion> {
-
- private static final long serialVersionUID = 1L;
- public InstanceConditionalTest splitTest;
+ private static final long serialVersionUID = 1L;
- public double[][] resultingClassDistributions;
+ public InstanceConditionalTest splitTest;
- public double merit;
-
- public AttributeSplitSuggestion() {}
+ public double[][] resultingClassDistributions;
- public AttributeSplitSuggestion(InstanceConditionalTest splitTest,
- double[][] resultingClassDistributions, double merit) {
- this.splitTest = splitTest;
- this.resultingClassDistributions = resultingClassDistributions.clone();
- this.merit = merit;
- }
+ public double merit;
- public int numSplits() {
- return this.resultingClassDistributions.length;
- }
+ public AttributeSplitSuggestion() {
+ }
- public double[] resultingClassDistributionFromSplit(int splitIndex) {
- return this.resultingClassDistributions[splitIndex].clone();
- }
+ public AttributeSplitSuggestion(InstanceConditionalTest splitTest,
+ double[][] resultingClassDistributions, double merit) {
+ this.splitTest = splitTest;
+ this.resultingClassDistributions = resultingClassDistributions.clone();
+ this.merit = merit;
+ }
- @Override
- public int compareTo(AttributeSplitSuggestion comp) {
- return Double.compare(this.merit, comp.merit);
- }
+ public int numSplits() {
+ return this.resultingClassDistributions.length;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // do nothing
- }
+ public double[] resultingClassDistributionFromSplit(int splitIndex) {
+ return this.resultingClassDistributions[splitIndex].clone();
+ }
+
+ @Override
+ public int compareTo(AttributeSplitSuggestion comp) {
+ return Double.compare(this.merit, comp.merit);
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // do nothing
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java
index d6adc2e..a6cdf80 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java
@@ -25,49 +25,57 @@
import com.yahoo.labs.samoa.moa.options.OptionHandler;
/**
- * Interface for observing the class data distribution for an attribute.
- * This observer monitors the class distribution of a given attribute.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * Interface for observing the class data distribution for an attribute. This
+ * observer monitors the class distribution of a given attribute. Used in naive
+ * Bayes and decision trees to monitor data statistics on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface AttributeClassObserver extends OptionHandler {
- /**
- * Updates statistics of this observer given an attribute value, a class
- * and the weight of the instance observed
- *
- * @param attVal the value of the attribute
- * @param classVal the class
- * @param weight the weight of the instance
- */
- public void observeAttributeClass(double attVal, int classVal, double weight);
+ /**
+ * Updates statistics of this observer given an attribute value, a class and
+ * the weight of the instance observed
+ *
+ * @param attVal
+ * the value of the attribute
+ * @param classVal
+ * the class
+ * @param weight
+ * the weight of the instance
+ */
+ public void observeAttributeClass(double attVal, int classVal, double weight);
- /**
- * Gets the probability for an attribute value given a class
- *
- * @param attVal the attribute value
- * @param classVal the class
- * @return probability for an attribute value given a class
- */
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal);
+ /**
+ * Gets the probability for an attribute value given a class
+ *
+ * @param attVal
+ * the attribute value
+ * @param classVal
+ * the class
+ * @return probability for an attribute value given a class
+ */
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal);
- /**
- * Gets the best split suggestion given a criterion and a class distribution
- *
- * @param criterion the split criterion to use
- * @param preSplitDist the class distribution before the split
- * @param attIndex the attribute index
- * @param binaryOnly true to use binary splits
- * @return suggestion of best attribute split
- */
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly);
+ /**
+ * Gets the best split suggestion given a criterion and a class distribution
+ *
+ * @param criterion
+ * the split criterion to use
+ * @param preSplitDist
+ * the class distribution before the split
+ * @param attIndex
+ * the attribute index
+ * @param binaryOnly
+ * true to use binary splits
+ * @return suggestion of best attribute split
+ */
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly);
+ public void observeAttributeTarget(double attVal, double target);
- public void observeAttributeTarget(double attVal, double target);
-
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java
index e9bb2f9..2b45209 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java
@@ -30,154 +30,155 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Class for observing the class data distribution for a numeric attribute using a binary tree.
- * This observer monitors the class distribution of a given attribute.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * Class for observing the class data distribution for a numeric attribute using
+ * a binary tree. This observer monitors the class distribution of a given
+ * attribute. Used in naive Bayes and decision trees to monitor data statistics
+ * on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class BinaryTreeNumericAttributeClassObserver extends AbstractOptionHandler
- implements NumericAttributeClassObserver {
+ implements NumericAttributeClassObserver {
+
+ private static final long serialVersionUID = 1L;
+
+ public class Node implements Serializable {
private static final long serialVersionUID = 1L;
- public class Node implements Serializable {
+ public double cut_point;
- private static final long serialVersionUID = 1L;
+ public DoubleVector classCountsLeft = new DoubleVector();
- public double cut_point;
+ public DoubleVector classCountsRight = new DoubleVector();
- public DoubleVector classCountsLeft = new DoubleVector();
+ public Node left;
- public DoubleVector classCountsRight = new DoubleVector();
+ public Node right;
- public Node left;
-
- public Node right;
-
- public Node(double val, int label, double weight) {
- this.cut_point = val;
- this.classCountsLeft.addToValue(label, weight);
- }
-
- public void insertValue(double val, int label, double weight) {
- if (val == this.cut_point) {
- this.classCountsLeft.addToValue(label, weight);
- } else if (val <= this.cut_point) {
- this.classCountsLeft.addToValue(label, weight);
- if (this.left == null) {
- this.left = new Node(val, label, weight);
- } else {
- this.left.insertValue(val, label, weight);
- }
- } else { // val > cut_point
- this.classCountsRight.addToValue(label, weight);
- if (this.right == null) {
- this.right = new Node(val, label, weight);
- } else {
- this.right.insertValue(val, label, weight);
- }
- }
- }
+ public Node(double val, int label, double weight) {
+ this.cut_point = val;
+ this.classCountsLeft.addToValue(label, weight);
}
- public Node root = null;
-
- @Override
- public void observeAttributeClass(double attVal, int classVal, double weight) {
- if (Double.isNaN(attVal)) { //Instance.isMissingValue(attVal)
+ public void insertValue(double val, int label, double weight) {
+ if (val == this.cut_point) {
+ this.classCountsLeft.addToValue(label, weight);
+ } else if (val <= this.cut_point) {
+ this.classCountsLeft.addToValue(label, weight);
+ if (this.left == null) {
+ this.left = new Node(val, label, weight);
} else {
- if (this.root == null) {
- this.root = new Node(attVal, classVal, weight);
- } else {
- this.root.insertValue(attVal, classVal, weight);
- }
+ this.left.insertValue(val, label, weight);
}
- }
-
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- // TODO: NaiveBayes broken until implemented
- return 0.0;
- }
-
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly) {
- return searchForBestSplitOption(this.root, null, null, null, null, false,
- criterion, preSplitDist, attIndex);
- }
-
- protected AttributeSplitSuggestion searchForBestSplitOption(
- Node currentNode, AttributeSplitSuggestion currentBestOption,
- double[] actualParentLeft,
- double[] parentLeft, double[] parentRight, boolean leftChild,
- SplitCriterion criterion, double[] preSplitDist, int attIndex) {
- if (currentNode == null) {
- return currentBestOption;
- }
- DoubleVector leftDist = new DoubleVector();
- DoubleVector rightDist = new DoubleVector();
- if (parentLeft == null) {
- leftDist.addValues(currentNode.classCountsLeft);
- rightDist.addValues(currentNode.classCountsRight);
+ } else { // val > cut_point
+ this.classCountsRight.addToValue(label, weight);
+ if (this.right == null) {
+ this.right = new Node(val, label, weight);
} else {
- leftDist.addValues(parentLeft);
- rightDist.addValues(parentRight);
- if (leftChild) {
- //get the exact statistics of the parent value
- DoubleVector exactParentDist = new DoubleVector();
- exactParentDist.addValues(actualParentLeft);
- exactParentDist.subtractValues(currentNode.classCountsLeft);
- exactParentDist.subtractValues(currentNode.classCountsRight);
-
- // move the subtrees
- leftDist.subtractValues(currentNode.classCountsRight);
- rightDist.addValues(currentNode.classCountsRight);
-
- // move the exact value from the parent
- rightDist.addValues(exactParentDist);
- leftDist.subtractValues(exactParentDist);
-
- } else {
- leftDist.addValues(currentNode.classCountsLeft);
- rightDist.subtractValues(currentNode.classCountsLeft);
- }
+ this.right.insertValue(val, label, weight);
}
- double[][] postSplitDists = new double[][]{leftDist.getArrayRef(),
- rightDist.getArrayRef()};
- double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
- if ((currentBestOption == null) || (merit > currentBestOption.merit)) {
- currentBestOption = new AttributeSplitSuggestion(
- new NumericAttributeBinaryTest(attIndex,
- currentNode.cut_point, true), postSplitDists, merit);
+ }
+ }
+ }
- }
- currentBestOption = searchForBestSplitOption(currentNode.left,
- currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], true,
- criterion, preSplitDist, attIndex);
- currentBestOption = searchForBestSplitOption(currentNode.right,
- currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], false,
- criterion, preSplitDist, attIndex);
- return currentBestOption;
- }
+ public Node root = null;
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
+ @Override
+ public void observeAttributeClass(double attVal, int classVal, double weight) {
+ if (Double.isNaN(attVal)) { // Instance.isMissingValue(attVal)
+ } else {
+ if (this.root == null) {
+ this.root = new Node(attVal, classVal, weight);
+ } else {
+ this.root.insertValue(attVal, classVal, weight);
+ }
}
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- // TODO Auto-generated method stub
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ // TODO: NaiveBayes broken until implemented
+ return 0.0;
+ }
+
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly) {
+ return searchForBestSplitOption(this.root, null, null, null, null, false,
+ criterion, preSplitDist, attIndex);
+ }
+
+ protected AttributeSplitSuggestion searchForBestSplitOption(
+ Node currentNode, AttributeSplitSuggestion currentBestOption,
+ double[] actualParentLeft,
+ double[] parentLeft, double[] parentRight, boolean leftChild,
+ SplitCriterion criterion, double[] preSplitDist, int attIndex) {
+ if (currentNode == null) {
+ return currentBestOption;
}
-
- @Override
- public void observeAttributeTarget(double attVal, double target) {
- throw new UnsupportedOperationException("Not supported yet.");
+ DoubleVector leftDist = new DoubleVector();
+ DoubleVector rightDist = new DoubleVector();
+ if (parentLeft == null) {
+ leftDist.addValues(currentNode.classCountsLeft);
+ rightDist.addValues(currentNode.classCountsRight);
+ } else {
+ leftDist.addValues(parentLeft);
+ rightDist.addValues(parentRight);
+ if (leftChild) {
+ // get the exact statistics of the parent value
+ DoubleVector exactParentDist = new DoubleVector();
+ exactParentDist.addValues(actualParentLeft);
+ exactParentDist.subtractValues(currentNode.classCountsLeft);
+ exactParentDist.subtractValues(currentNode.classCountsRight);
+
+ // move the subtrees
+ leftDist.subtractValues(currentNode.classCountsRight);
+ rightDist.addValues(currentNode.classCountsRight);
+
+ // move the exact value from the parent
+ rightDist.addValues(exactParentDist);
+ leftDist.subtractValues(exactParentDist);
+
+ } else {
+ leftDist.addValues(currentNode.classCountsLeft);
+ rightDist.subtractValues(currentNode.classCountsLeft);
+ }
}
-
+ double[][] postSplitDists = new double[][] { leftDist.getArrayRef(),
+ rightDist.getArrayRef() };
+ double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
+ if ((currentBestOption == null) || (merit > currentBestOption.merit)) {
+ currentBestOption = new AttributeSplitSuggestion(
+ new NumericAttributeBinaryTest(attIndex,
+ currentNode.cut_point, true), postSplitDists, merit);
+
+ }
+ currentBestOption = searchForBestSplitOption(currentNode.left,
+ currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], true,
+ criterion, preSplitDist, attIndex);
+ currentBestOption = searchForBestSplitOption(currentNode.right,
+ currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], false,
+ criterion, preSplitDist, attIndex);
+ return currentBestOption;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
+
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
+
+ @Override
+ public void observeAttributeTarget(double attVal, double target) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java
index a68cad9..eeffebd 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers;
/*
@@ -21,7 +20,6 @@
* #L%
*/
-
import java.io.Serializable;
import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
@@ -30,119 +28,128 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Class for observing the class data distribution for a numeric attribute using a binary tree.
- * This observer monitors the class distribution of a given attribute.
- *
- * <p>Learning Adaptive Model Rules from High-Speed Data Streams, ECML 2013, E. Almeida, C. Ferreira, P. Kosina and J. Gama; </p>
- *
+ * Class for observing the class data distribution for a numeric attribute using
+ * a binary tree. This observer monitors the class distribution of a given
+ * attribute.
+ *
+ * <p>
+ * Learning Adaptive Model Rules from High-Speed Data Streams, ECML 2013, E.
+ * Almeida, C. Ferreira, P. Kosina and J. Gama;
+ * </p>
+ *
* @author E. Almeida, J. Gama
* @version $Revision: 2$
*/
public class BinaryTreeNumericAttributeClassObserverRegression extends AbstractOptionHandler
- implements NumericAttributeClassObserver {
+ implements NumericAttributeClassObserver {
- public static final long serialVersionUID = 1L;
-
- public class Node implements Serializable {
+ public static final long serialVersionUID = 1L;
- private static final long serialVersionUID = 1L;
+ public class Node implements Serializable {
- public double cut_point;
-
- public double[] lessThan; //This array maintains statistics for the instance reaching the node with attribute values less than or iqual to the cutpoint.
-
- public double[] greaterThan; //This array maintains statistics for the instance reaching the node with attribute values greater than to the cutpoint.
+ private static final long serialVersionUID = 1L;
- public Node left;
+ public double cut_point;
- public Node right;
+ public double[] lessThan; // This array maintains statistics for the
+ // instance reaching the node with attribute
+ // values less than or iqual to the cutpoint.
- public Node(double val, double target) {
- this.cut_point = val;
- this.lessThan = new double[3];
- this.greaterThan = new double[3];
- this.lessThan[0] = target; //The sum of their target attribute values.
- this.lessThan[1] = target * target; //The sum of the squared target attribute values.
- this.lessThan[2] = 1.0; //A counter of the number of instances that have reached the node.
- this.greaterThan[0] = 0.0;
- this.greaterThan[1] = 0.0;
- this.greaterThan[2] = 0.0;
- }
+ public double[] greaterThan; // This array maintains statistics for the
+ // instance reaching the node with attribute
+ // values greater than to the cutpoint.
- public void insertValue(double val, double target) {
- if (val == this.cut_point) {
- this.lessThan[0] = this.lessThan[0] + target;
- this.lessThan[1] = this.lessThan[1] + (target * target);
- this.lessThan[2] = this.lessThan[2] + 1;
- } else if (val <= this.cut_point) {
- this.lessThan[0] = this.lessThan[0] + target;
- this.lessThan[1] = this.lessThan[1] + (target * target);
- this.lessThan[2] = this.lessThan[2] + 1;
- if (this.left == null) {
- this.left = new Node(val, target);
- } else {
- this.left.insertValue(val, target);
- }
- } else {
- this.greaterThan[0] = this.greaterThan[0] + target;
- this.greaterThan[1] = this.greaterThan[1] + (target*target);
- this.greaterThan[2] = this.greaterThan[2] + 1;
- if (this.right == null) {
-
- this.right = new Node(val, target);
- } else {
- this.right.insertValue(val, target);
- }
- }
+ public Node left;
+
+ public Node right;
+
+ public Node(double val, double target) {
+ this.cut_point = val;
+ this.lessThan = new double[3];
+ this.greaterThan = new double[3];
+ this.lessThan[0] = target; // The sum of their target attribute values.
+ this.lessThan[1] = target * target; // The sum of the squared target
+ // attribute values.
+ this.lessThan[2] = 1.0; // A counter of the number of instances that have
+ // reached the node.
+ this.greaterThan[0] = 0.0;
+ this.greaterThan[1] = 0.0;
+ this.greaterThan[2] = 0.0;
+ }
+
+ public void insertValue(double val, double target) {
+ if (val == this.cut_point) {
+ this.lessThan[0] = this.lessThan[0] + target;
+ this.lessThan[1] = this.lessThan[1] + (target * target);
+ this.lessThan[2] = this.lessThan[2] + 1;
+ } else if (val <= this.cut_point) {
+ this.lessThan[0] = this.lessThan[0] + target;
+ this.lessThan[1] = this.lessThan[1] + (target * target);
+ this.lessThan[2] = this.lessThan[2] + 1;
+ if (this.left == null) {
+ this.left = new Node(val, target);
+ } else {
+ this.left.insertValue(val, target);
}
- }
+ } else {
+ this.greaterThan[0] = this.greaterThan[0] + target;
+ this.greaterThan[1] = this.greaterThan[1] + (target * target);
+ this.greaterThan[2] = this.greaterThan[2] + 1;
+ if (this.right == null) {
- public Node root1 = null;
-
- public void observeAttributeTarget(double attVal, double target){
- if (!Double.isNaN(attVal)) {
- if (this.root1 == null) {
- this.root1 = new Node(attVal, target);
- } else {
- this.root1.insertValue(attVal, target);
- }
+ this.right = new Node(val, target);
+ } else {
+ this.right.insertValue(val, target);
+ }
}
}
+ }
- @Override
- public void observeAttributeClass(double attVal, int classVal, double weight) {
-
- }
+ public Node root1 = null;
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- return 0.0;
+ public void observeAttributeTarget(double attVal, double target) {
+ if (!Double.isNaN(attVal)) {
+ if (this.root1 == null) {
+ this.root1 = new Node(attVal, target);
+ } else {
+ this.root1.insertValue(attVal, target);
+ }
}
+ }
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly) {
- return searchForBestSplitOption(this.root1, null, null, null, null, false,
- criterion, preSplitDist, attIndex);
- }
+ @Override
+ public void observeAttributeClass(double attVal, int classVal, double weight) {
- protected AttributeSplitSuggestion searchForBestSplitOption(
- Node currentNode, AttributeSplitSuggestion currentBestOption,
- double[] actualParentLeft,
- double[] parentLeft, double[] parentRight, boolean leftChild,
- SplitCriterion criterion, double[] preSplitDist, int attIndex) {
-
- return currentBestOption;
- }
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- }
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ return 0.0;
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- }
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly) {
+ return searchForBestSplitOption(this.root1, null, null, null, null, false,
+ criterion, preSplitDist, attIndex);
+ }
+
+ protected AttributeSplitSuggestion searchForBestSplitOption(
+ Node currentNode, AttributeSplitSuggestion currentBestOption,
+ double[] actualParentLeft,
+ double[] parentLeft, double[] parentRight, boolean leftChild,
+ SplitCriterion criterion, double[] preSplitDist, int attIndex) {
+
+ return currentBestOption;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ }
+
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ }
}
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java
index e756fcd..fe16447 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java
@@ -21,14 +21,14 @@
*/
/**
- * Interface for observing the class data distribution for a discrete (nominal) attribute.
- * This observer monitors the class distribution of a given attribute.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * Interface for observing the class data distribution for a discrete (nominal)
+ * attribute. This observer monitors the class distribution of a given
+ * attribute. Used in naive Bayes and decision trees to monitor data statistics
+ * on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface DiscreteAttributeClassObserver extends AttributeClassObserver {
-
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java
index 2434652..dac0b7d 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java
@@ -1,4 +1,3 @@
-
/* Project Knowledge Discovery from Data Streams, FCT LIAAD-INESC TEC,
*
* Contact: jgama@fep.up.pt
@@ -35,206 +34,220 @@
import com.yahoo.labs.samoa.moa.core.ObjectRepository;
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
-public class FIMTDDNumericAttributeClassObserver extends BinaryTreeNumericAttributeClassObserver implements NumericAttributeClassObserver {
+public class FIMTDDNumericAttributeClassObserver extends BinaryTreeNumericAttributeClassObserver implements
+ NumericAttributeClassObserver {
+
+ private static final long serialVersionUID = 1L;
+
+ protected class Node implements Serializable {
private static final long serialVersionUID = 1L;
- protected class Node implements Serializable {
+ // The split point to use
+ public double cut_point;
- private static final long serialVersionUID = 1L;
+ // E-BST statistics
+ public DoubleVector leftStatistics = new DoubleVector();
+ public DoubleVector rightStatistics = new DoubleVector();
- // The split point to use
- public double cut_point;
+ // Child nodes
+ public Node left;
+ public Node right;
- // E-BST statistics
- public DoubleVector leftStatistics = new DoubleVector();
- public DoubleVector rightStatistics = new DoubleVector();
-
- // Child nodes
- public Node left;
- public Node right;
-
- public Node(double val, double label, double weight) {
- this.cut_point = val;
- this.leftStatistics.addToValue(0, 1);
- this.leftStatistics.addToValue(1, label);
- this.leftStatistics.addToValue(2, label * label);
- }
-
- /**
- * Insert a new value into the tree, updating both the sum of values and
- * sum of squared values arrays
- */
- public void insertValue(double val, double label, double weight) {
-
- // If the new value equals the value stored in a node, update
- // the left (<=) node information
- if (val == this.cut_point) {
- this.leftStatistics.addToValue(0, 1);
- this.leftStatistics.addToValue(1, label);
- this.leftStatistics.addToValue(2, label * label);
- } // If the new value is less than the value in a node, update the
- // left distribution and send the value down to the left child node.
- // If no left child exists, create one
- else if (val <= this.cut_point) {
- this.leftStatistics.addToValue(0, 1);
- this.leftStatistics.addToValue(1, label);
- this.leftStatistics.addToValue(2, label * label);
- if (this.left == null) {
- this.left = new Node(val, label, weight);
- } else {
- this.left.insertValue(val, label, weight);
- }
- } // If the new value is greater than the value in a node, update the
- // right (>) distribution and send the value down to the right child node.
- // If no right child exists, create one
- else { // val > cut_point
- this.rightStatistics.addToValue(0, 1);
- this.rightStatistics.addToValue(1, label);
- this.rightStatistics.addToValue(2, label * label);
- if (this.right == null) {
- this.right = new Node(val, label, weight);
- } else {
- this.right.insertValue(val, label, weight);
- }
- }
- }
- }
-
- // Root node of the E-BST structure for this attribute
- public Node root = null;
-
- // Global variables for use in the FindBestSplit algorithm
- double sumTotalLeft;
- double sumTotalRight;
- double sumSqTotalLeft;
- double sumSqTotalRight;
- double countRightTotal;
- double countLeftTotal;
-
- public void observeAttributeClass(double attVal, double classVal, double weight) {
- if (!Double.isNaN(attVal)) {
- if (this.root == null) {
- this.root = new Node(attVal, classVal, weight);
- } else {
- this.root.insertValue(attVal, classVal, weight);
- }
- }
- }
-
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- // TODO: NaiveBayes broken until implemented
- return 0.0;
- }
-
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist, int attIndex, boolean binaryOnly) {
-
- // Initialise global variables
- sumTotalLeft = 0;
- sumTotalRight = preSplitDist[1];
- sumSqTotalLeft = 0;
- sumSqTotalRight = preSplitDist[2];
- countLeftTotal = 0;
- countRightTotal = preSplitDist[0];
- return searchForBestSplitOption(this.root, null, criterion, attIndex);
+ public Node(double val, double label, double weight) {
+ this.cut_point = val;
+ this.leftStatistics.addToValue(0, 1);
+ this.leftStatistics.addToValue(1, label);
+ this.leftStatistics.addToValue(2, label * label);
}
/**
- * Implementation of the FindBestSplit algorithm from E.Ikonomovska et al.
+ * Insert a new value into the tree, updating both the sum of values and sum
+ * of squared values arrays
*/
- protected AttributeSplitSuggestion searchForBestSplitOption(Node currentNode, AttributeSplitSuggestion currentBestOption, SplitCriterion criterion, int attIndex) {
- // Return null if the current node is null or we have finished looking through all the possible splits
- if (currentNode == null || countRightTotal == 0.0) {
- return currentBestOption;
+ public void insertValue(double val, double label, double weight) {
+
+ // If the new value equals the value stored in a node, update
+ // the left (<=) node information
+ if (val == this.cut_point) {
+ this.leftStatistics.addToValue(0, 1);
+ this.leftStatistics.addToValue(1, label);
+ this.leftStatistics.addToValue(2, label * label);
+ } // If the new value is less than the value in a node, update the
+ // left distribution and send the value down to the left child node.
+ // If no left child exists, create one
+ else if (val <= this.cut_point) {
+ this.leftStatistics.addToValue(0, 1);
+ this.leftStatistics.addToValue(1, label);
+ this.leftStatistics.addToValue(2, label * label);
+ if (this.left == null) {
+ this.left = new Node(val, label, weight);
+ } else {
+ this.left.insertValue(val, label, weight);
}
-
- if (currentNode.left != null) {
- currentBestOption = searchForBestSplitOption(currentNode.left, currentBestOption, criterion, attIndex);
+ } // If the new value is greater than the value in a node, update the
+ // right (>) distribution and send the value down to the right child node.
+ // If no right child exists, create one
+ else { // val > cut_point
+ this.rightStatistics.addToValue(0, 1);
+ this.rightStatistics.addToValue(1, label);
+ this.rightStatistics.addToValue(2, label * label);
+ if (this.right == null) {
+ this.right = new Node(val, label, weight);
+ } else {
+ this.right.insertValue(val, label, weight);
}
+ }
+ }
+ }
- sumTotalLeft += currentNode.leftStatistics.getValue(1);
- sumTotalRight -= currentNode.leftStatistics.getValue(1);
- sumSqTotalLeft += currentNode.leftStatistics.getValue(2);
- sumSqTotalRight -= currentNode.leftStatistics.getValue(2);
- countLeftTotal += currentNode.leftStatistics.getValue(0);
- countRightTotal -= currentNode.leftStatistics.getValue(0);
+ // Root node of the E-BST structure for this attribute
+ public Node root = null;
- double[][] postSplitDists = new double[][]{{countLeftTotal, sumTotalLeft, sumSqTotalLeft}, {countRightTotal, sumTotalRight, sumSqTotalRight}};
- double[] preSplitDist = new double[]{(countLeftTotal + countRightTotal), (sumTotalLeft + sumTotalRight), (sumSqTotalLeft + sumSqTotalRight)};
- double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
+ // Global variables for use in the FindBestSplit algorithm
+ double sumTotalLeft;
+ double sumTotalRight;
+ double sumSqTotalLeft;
+ double sumSqTotalRight;
+ double countRightTotal;
+ double countLeftTotal;
- if ((currentBestOption == null) || (merit > currentBestOption.merit)) {
- currentBestOption = new AttributeSplitSuggestion(
- new NumericAttributeBinaryTest(attIndex,
- currentNode.cut_point, true), postSplitDists, merit);
+ public void observeAttributeClass(double attVal, double classVal, double weight) {
+ if (!Double.isNaN(attVal)) {
+ if (this.root == null) {
+ this.root = new Node(attVal, classVal, weight);
+ } else {
+ this.root.insertValue(attVal, classVal, weight);
+ }
+ }
+ }
- }
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ // TODO: NaiveBayes broken until implemented
+ return 0.0;
+ }
- if (currentNode.right != null) {
- currentBestOption = searchForBestSplitOption(currentNode.right, currentBestOption, criterion, attIndex);
- }
- sumTotalLeft -= currentNode.leftStatistics.getValue(1);
- sumTotalRight += currentNode.leftStatistics.getValue(1);
- sumSqTotalLeft -= currentNode.leftStatistics.getValue(2);
- sumSqTotalRight += currentNode.leftStatistics.getValue(2);
- countLeftTotal -= currentNode.leftStatistics.getValue(0);
- countRightTotal += currentNode.leftStatistics.getValue(0);
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist,
+ int attIndex, boolean binaryOnly) {
- return currentBestOption;
+ // Initialise global variables
+ sumTotalLeft = 0;
+ sumTotalRight = preSplitDist[1];
+ sumSqTotalLeft = 0;
+ sumSqTotalRight = preSplitDist[2];
+ countLeftTotal = 0;
+ countRightTotal = preSplitDist[0];
+ return searchForBestSplitOption(this.root, null, criterion, attIndex);
+ }
+
+ /**
+ * Implementation of the FindBestSplit algorithm from E. Ikonomovska et al.
+ */
+ protected AttributeSplitSuggestion searchForBestSplitOption(Node currentNode,
+ AttributeSplitSuggestion currentBestOption, SplitCriterion criterion, int attIndex) {
+ // Return the current best option if the current node is null or we have
+ // finished looking through all the possible splits
+ if (currentNode == null || countRightTotal == 0.0) {
+ return currentBestOption;
}
- /**
- * A method to remove all nodes in the E-BST in which it and all it's
- * children represent 'bad' split points
- */
- public void removeBadSplits(SplitCriterion criterion, double lastCheckRatio, double lastCheckSDR, double lastCheckE) {
- removeBadSplitNodes(criterion, this.root, lastCheckRatio, lastCheckSDR, lastCheckE);
+ if (currentNode.left != null) {
+ currentBestOption = searchForBestSplitOption(currentNode.left, currentBestOption, criterion, attIndex);
}
- /**
- * Recursive method that first checks all of a node's children before
- * deciding if it is 'bad' and may be removed
- */
- private boolean removeBadSplitNodes(SplitCriterion criterion, Node currentNode, double lastCheckRatio, double lastCheckSDR, double lastCheckE) {
- boolean isBad = false;
+ sumTotalLeft += currentNode.leftStatistics.getValue(1);
+ sumTotalRight -= currentNode.leftStatistics.getValue(1);
+ sumSqTotalLeft += currentNode.leftStatistics.getValue(2);
+ sumSqTotalRight -= currentNode.leftStatistics.getValue(2);
+ countLeftTotal += currentNode.leftStatistics.getValue(0);
+ countRightTotal -= currentNode.leftStatistics.getValue(0);
- if (currentNode == null) {
- return true;
- }
+ double[][] postSplitDists = new double[][] { { countLeftTotal, sumTotalLeft, sumSqTotalLeft },
+ { countRightTotal, sumTotalRight, sumSqTotalRight } };
+ double[] preSplitDist = new double[] { (countLeftTotal + countRightTotal), (sumTotalLeft + sumTotalRight),
+ (sumSqTotalLeft + sumSqTotalRight) };
+ double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
- if (currentNode.left != null) {
- isBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
- }
+ if ((currentBestOption == null) || (merit > currentBestOption.merit)) {
+ currentBestOption = new AttributeSplitSuggestion(
+ new NumericAttributeBinaryTest(attIndex,
+ currentNode.cut_point, true), postSplitDists, merit);
- if (currentNode.right != null && isBad) {
- isBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
- }
-
- if (isBad) {
-
- double[][] postSplitDists = new double[][]{{currentNode.leftStatistics.getValue(0), currentNode.leftStatistics.getValue(1), currentNode.leftStatistics.getValue(2)}, {currentNode.rightStatistics.getValue(0), currentNode.rightStatistics.getValue(1), currentNode.rightStatistics.getValue(2)}};
- double[] preSplitDist = new double[]{(currentNode.leftStatistics.getValue(0) + currentNode.rightStatistics.getValue(0)), (currentNode.leftStatistics.getValue(1) + currentNode.rightStatistics.getValue(1)), (currentNode.leftStatistics.getValue(2) + currentNode.rightStatistics.getValue(2))};
- double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
-
- if ((merit / lastCheckSDR) < (lastCheckRatio - (2 * lastCheckE))) {
- currentNode = null;
- return true;
- }
- }
-
- return false;
}
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
+ if (currentNode.right != null) {
+ currentBestOption = searchForBestSplitOption(currentNode.right, currentBestOption, criterion, attIndex);
+ }
+ sumTotalLeft -= currentNode.leftStatistics.getValue(1);
+ sumTotalRight += currentNode.leftStatistics.getValue(1);
+ sumSqTotalLeft -= currentNode.leftStatistics.getValue(2);
+ sumSqTotalRight += currentNode.leftStatistics.getValue(2);
+ countLeftTotal -= currentNode.leftStatistics.getValue(0);
+ countRightTotal += currentNode.leftStatistics.getValue(0);
+
+ return currentBestOption;
+ }
+
+ /**
+ * A method to remove all nodes in the E-BST in which it and all its children
+ * represent 'bad' split points
+ */
+ public void removeBadSplits(SplitCriterion criterion, double lastCheckRatio, double lastCheckSDR, double lastCheckE) {
+ removeBadSplitNodes(criterion, this.root, lastCheckRatio, lastCheckSDR, lastCheckE);
+ }
+
+ /**
+ * Recursive method that first checks all of a node's children before deciding
+ * if it is 'bad' and may be removed
+ */
+ private boolean removeBadSplitNodes(SplitCriterion criterion, Node currentNode, double lastCheckRatio,
+ double lastCheckSDR, double lastCheckE) {
+ boolean isBad = false;
+
+ if (currentNode == null) {
+ return true;
}
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- // TODO Auto-generated method stub
+ if (currentNode.left != null) {
+ isBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
}
+
+ if (currentNode.right != null && isBad) {
+ isBad = removeBadSplitNodes(criterion, currentNode.right, lastCheckRatio, lastCheckSDR, lastCheckE);
+ }
+
+ if (isBad) {
+
+ double[][] postSplitDists = new double[][] {
+ { currentNode.leftStatistics.getValue(0), currentNode.leftStatistics.getValue(1),
+ currentNode.leftStatistics.getValue(2) },
+ { currentNode.rightStatistics.getValue(0), currentNode.rightStatistics.getValue(1),
+ currentNode.rightStatistics.getValue(2) } };
+ double[] preSplitDist = new double[] {
+ (currentNode.leftStatistics.getValue(0) + currentNode.rightStatistics.getValue(0)),
+ (currentNode.leftStatistics.getValue(1) + currentNode.rightStatistics.getValue(1)),
+ (currentNode.leftStatistics.getValue(2) + currentNode.rightStatistics.getValue(2)) };
+ double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
+
+ if ((merit / lastCheckSDR) < (lastCheckRatio - (2 * lastCheckE))) {
+ currentNode = null;
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
+
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java
index 21f58b1..107fa29 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java
@@ -37,146 +37,147 @@
import com.github.javacliparser.IntOption;
/**
- * Class for observing the class data distribution for a numeric attribute using gaussian estimators.
- * This observer monitors the class distribution of a given attribute.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * Class for observing the class data distribution for a numeric attribute using
+ * gaussian estimators. This observer monitors the class distribution of a given
+ * attribute. Used in naive Bayes and decision trees to monitor data statistics
+ * on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class GaussianNumericAttributeClassObserver extends AbstractOptionHandler
- implements NumericAttributeClassObserver {
+ implements NumericAttributeClassObserver {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected DoubleVector minValueObservedPerClass = new DoubleVector();
+ protected DoubleVector minValueObservedPerClass = new DoubleVector();
- protected DoubleVector maxValueObservedPerClass = new DoubleVector();
+ protected DoubleVector maxValueObservedPerClass = new DoubleVector();
- protected AutoExpandVector<GaussianEstimator> attValDistPerClass = new AutoExpandVector<>();
+ protected AutoExpandVector<GaussianEstimator> attValDistPerClass = new AutoExpandVector<>();
- /**
- * @param classVal
- * @return The requested Estimator if it exists, or null if not present.
- */
- public GaussianEstimator getEstimator(int classVal) {
- return this.attValDistPerClass.get(classVal);
- }
-
- public IntOption numBinsOption = new IntOption("numBins", 'n',
- "The number of bins.", 10, 1, Integer.MAX_VALUE);
+ /**
+ * @param classVal
+ * @return The requested Estimator if it exists, or null if not present.
+ */
+ public GaussianEstimator getEstimator(int classVal) {
+ return this.attValDistPerClass.get(classVal);
+ }
- @Override
- public void observeAttributeClass(double attVal, int classVal, double weight) {
- if (!Utils.isMissingValue(attVal)) {
- GaussianEstimator valDist = this.attValDistPerClass.get(classVal);
- if (valDist == null) {
- valDist = new GaussianEstimator();
- this.attValDistPerClass.set(classVal, valDist);
- this.minValueObservedPerClass.setValue(classVal, attVal);
- this.maxValueObservedPerClass.setValue(classVal, attVal);
- } else {
- if (attVal < this.minValueObservedPerClass.getValue(classVal)) {
- this.minValueObservedPerClass.setValue(classVal, attVal);
- }
- if (attVal > this.maxValueObservedPerClass.getValue(classVal)) {
- this.maxValueObservedPerClass.setValue(classVal, attVal);
- }
- }
- valDist.addObservation(attVal, weight);
+ public IntOption numBinsOption = new IntOption("numBins", 'n',
+ "The number of bins.", 10, 1, Integer.MAX_VALUE);
+
+ @Override
+ public void observeAttributeClass(double attVal, int classVal, double weight) {
+ if (!Utils.isMissingValue(attVal)) {
+ GaussianEstimator valDist = this.attValDistPerClass.get(classVal);
+ if (valDist == null) {
+ valDist = new GaussianEstimator();
+ this.attValDistPerClass.set(classVal, valDist);
+ this.minValueObservedPerClass.setValue(classVal, attVal);
+ this.maxValueObservedPerClass.setValue(classVal, attVal);
+ } else {
+ if (attVal < this.minValueObservedPerClass.getValue(classVal)) {
+ this.minValueObservedPerClass.setValue(classVal, attVal);
}
- }
-
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- GaussianEstimator obs = this.attValDistPerClass.get(classVal);
- return obs != null ? obs.probabilityDensity(attVal) : 0.0;
- }
-
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly) {
- AttributeSplitSuggestion bestSuggestion = null;
- double[] suggestedSplitValues = getSplitPointSuggestions();
- for (double splitValue : suggestedSplitValues) {
- double[][] postSplitDists = getClassDistsResultingFromBinarySplit(splitValue);
- double merit = criterion.getMeritOfSplit(preSplitDist,
- postSplitDists);
- if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) {
- bestSuggestion = new AttributeSplitSuggestion(
- new NumericAttributeBinaryTest(attIndex, splitValue,
- true), postSplitDists, merit);
- }
+ if (attVal > this.maxValueObservedPerClass.getValue(classVal)) {
+ this.maxValueObservedPerClass.setValue(classVal, attVal);
}
- return bestSuggestion;
+ }
+ valDist.addObservation(attVal, weight);
}
+ }
- public double[] getSplitPointSuggestions() {
- Set<Double> suggestedSplitValues = new TreeSet<>();
- double minValue = Double.POSITIVE_INFINITY;
- double maxValue = Double.NEGATIVE_INFINITY;
- for (int i = 0; i < this.attValDistPerClass.size(); i++) {
- GaussianEstimator estimator = this.attValDistPerClass.get(i);
- if (estimator != null) {
- if (this.minValueObservedPerClass.getValue(i) < minValue) {
- minValue = this.minValueObservedPerClass.getValue(i);
- }
- if (this.maxValueObservedPerClass.getValue(i) > maxValue) {
- maxValue = this.maxValueObservedPerClass.getValue(i);
- }
- }
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ GaussianEstimator obs = this.attValDistPerClass.get(classVal);
+ return obs != null ? obs.probabilityDensity(attVal) : 0.0;
+ }
+
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly) {
+ AttributeSplitSuggestion bestSuggestion = null;
+ double[] suggestedSplitValues = getSplitPointSuggestions();
+ for (double splitValue : suggestedSplitValues) {
+ double[][] postSplitDists = getClassDistsResultingFromBinarySplit(splitValue);
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) {
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NumericAttributeBinaryTest(attIndex, splitValue,
+ true), postSplitDists, merit);
+ }
+ }
+ return bestSuggestion;
+ }
+
+ public double[] getSplitPointSuggestions() {
+ Set<Double> suggestedSplitValues = new TreeSet<>();
+ double minValue = Double.POSITIVE_INFINITY;
+ double maxValue = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < this.attValDistPerClass.size(); i++) {
+ GaussianEstimator estimator = this.attValDistPerClass.get(i);
+ if (estimator != null) {
+ if (this.minValueObservedPerClass.getValue(i) < minValue) {
+ minValue = this.minValueObservedPerClass.getValue(i);
}
- if (minValue < Double.POSITIVE_INFINITY) {
- double range = maxValue - minValue;
- for (int i = 0; i < this.numBinsOption.getValue(); i++) {
- double splitValue = range / (this.numBinsOption.getValue() + 1.0) * (i + 1)
- + minValue;
- if ((splitValue > minValue) && (splitValue < maxValue)) {
- suggestedSplitValues.add(splitValue);
- }
- }
+ if (this.maxValueObservedPerClass.getValue(i) > maxValue) {
+ maxValue = this.maxValueObservedPerClass.getValue(i);
}
- double[] suggestions = new double[suggestedSplitValues.size()];
- int i = 0;
- for (double suggestion : suggestedSplitValues) {
- suggestions[i++] = suggestion;
+ }
+ }
+ if (minValue < Double.POSITIVE_INFINITY) {
+ double range = maxValue - minValue;
+ for (int i = 0; i < this.numBinsOption.getValue(); i++) {
+ double splitValue = range / (this.numBinsOption.getValue() + 1.0) * (i + 1)
+ + minValue;
+ if ((splitValue > minValue) && (splitValue < maxValue)) {
+ suggestedSplitValues.add(splitValue);
}
- return suggestions;
+ }
}
+ double[] suggestions = new double[suggestedSplitValues.size()];
+ int i = 0;
+ for (double suggestion : suggestedSplitValues) {
+ suggestions[i++] = suggestion;
+ }
+ return suggestions;
+ }
- // assume all values equal to splitValue go to lhs
- public double[][] getClassDistsResultingFromBinarySplit(double splitValue) {
- DoubleVector lhsDist = new DoubleVector();
- DoubleVector rhsDist = new DoubleVector();
- for (int i = 0; i < this.attValDistPerClass.size(); i++) {
- GaussianEstimator estimator = this.attValDistPerClass.get(i);
- if (estimator != null) {
- if (splitValue < this.minValueObservedPerClass.getValue(i)) {
- rhsDist.addToValue(i, estimator.getTotalWeightObserved());
- } else if (splitValue >= this.maxValueObservedPerClass.getValue(i)) {
- lhsDist.addToValue(i, estimator.getTotalWeightObserved());
- } else {
- double[] weightDist = estimator.estimatedWeight_LessThan_EqualTo_GreaterThan_Value(splitValue);
- lhsDist.addToValue(i, weightDist[0] + weightDist[1]);
- rhsDist.addToValue(i, weightDist[2]);
- }
- }
+ // assume all values equal to splitValue go to lhs
+ public double[][] getClassDistsResultingFromBinarySplit(double splitValue) {
+ DoubleVector lhsDist = new DoubleVector();
+ DoubleVector rhsDist = new DoubleVector();
+ for (int i = 0; i < this.attValDistPerClass.size(); i++) {
+ GaussianEstimator estimator = this.attValDistPerClass.get(i);
+ if (estimator != null) {
+ if (splitValue < this.minValueObservedPerClass.getValue(i)) {
+ rhsDist.addToValue(i, estimator.getTotalWeightObserved());
+ } else if (splitValue >= this.maxValueObservedPerClass.getValue(i)) {
+ lhsDist.addToValue(i, estimator.getTotalWeightObserved());
+ } else {
+ double[] weightDist = estimator.estimatedWeight_LessThan_EqualTo_GreaterThan_Value(splitValue);
+ lhsDist.addToValue(i, weightDist[0] + weightDist[1]);
+ rhsDist.addToValue(i, weightDist[2]);
}
- return new double[][]{lhsDist.getArrayRef(), rhsDist.getArrayRef()};
+ }
}
+ return new double[][] { lhsDist.getArrayRef(), rhsDist.getArrayRef() };
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ }
- @Override
- public void observeAttributeTarget(double attVal, double target) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
+ @Override
+ public void observeAttributeTarget(double attVal, double target) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GreenwaldKhannaNumericAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GreenwaldKhannaNumericAttributeClassObserver.java
index 3de1146..9aaf0b8 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GreenwaldKhannaNumericAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/GreenwaldKhannaNumericAttributeClassObserver.java
@@ -34,93 +34,95 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Class for observing the class data distribution for a numeric attribute using Greenwald and Khanna methodology.
- * This observer monitors the class distribution of a given attribute.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * Class for observing the class data distribution for a numeric attribute using
+ * Greenwald and Khanna methodology. This observer monitors the class
+ * distribution of a given attribute. Used in naive Bayes and decision trees to
+ * monitor data statistics on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
-public class GreenwaldKhannaNumericAttributeClassObserver extends AbstractOptionHandler implements NumericAttributeClassObserver {
+public class GreenwaldKhannaNumericAttributeClassObserver extends AbstractOptionHandler implements
+ NumericAttributeClassObserver {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected AutoExpandVector<GreenwaldKhannaQuantileSummary> attValDistPerClass = new AutoExpandVector<>();
+ protected AutoExpandVector<GreenwaldKhannaQuantileSummary> attValDistPerClass = new AutoExpandVector<>();
- public IntOption numTuplesOption = new IntOption("numTuples", 'n',
- "The number of tuples.", 10, 1, Integer.MAX_VALUE);
+ public IntOption numTuplesOption = new IntOption("numTuples", 'n',
+ "The number of tuples.", 10, 1, Integer.MAX_VALUE);
- @Override
- public void observeAttributeClass(double attVal, int classVal, double weight) {
- if (!Utils.isMissingValue(attVal)) {
- GreenwaldKhannaQuantileSummary valDist = this.attValDistPerClass.get(classVal);
- if (valDist == null) {
- valDist = new GreenwaldKhannaQuantileSummary(this.numTuplesOption.getValue());
- this.attValDistPerClass.set(classVal, valDist);
- }
- // TODO: not taking weight into account
- valDist.insert(attVal);
+ @Override
+ public void observeAttributeClass(double attVal, int classVal, double weight) {
+ if (!Utils.isMissingValue(attVal)) {
+ GreenwaldKhannaQuantileSummary valDist = this.attValDistPerClass.get(classVal);
+ if (valDist == null) {
+ valDist = new GreenwaldKhannaQuantileSummary(this.numTuplesOption.getValue());
+ this.attValDistPerClass.set(classVal, valDist);
+ }
+ // TODO: not taking weight into account
+ valDist.insert(attVal);
+ }
+ }
+
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ // TODO: NaiveBayes broken until implemented
+ return 0.0;
+ }
+
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly) {
+ AttributeSplitSuggestion bestSuggestion = null;
+ for (GreenwaldKhannaQuantileSummary qs : this.attValDistPerClass) {
+ if (qs != null) {
+ double[] cutpoints = qs.getSuggestedCutpoints();
+ for (double cutpoint : cutpoints) {
+ double[][] postSplitDists = getClassDistsResultingFromBinarySplit(cutpoint);
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ if ((bestSuggestion == null)
+ || (merit > bestSuggestion.merit)) {
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NumericAttributeBinaryTest(attIndex,
+ cutpoint, true), postSplitDists, merit);
+ }
}
+ }
}
+ return bestSuggestion;
+ }
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- // TODO: NaiveBayes broken until implemented
- return 0.0;
+ // assume all values equal to splitValue go to lhs
+ public double[][] getClassDistsResultingFromBinarySplit(double splitValue) {
+ DoubleVector lhsDist = new DoubleVector();
+ DoubleVector rhsDist = new DoubleVector();
+ for (int i = 0; i < this.attValDistPerClass.size(); i++) {
+ GreenwaldKhannaQuantileSummary estimator = this.attValDistPerClass.get(i);
+ if (estimator != null) {
+ long countBelow = estimator.getCountBelow(splitValue);
+ lhsDist.addToValue(i, countBelow);
+ rhsDist.addToValue(i, estimator.getTotalCount() - countBelow);
+ }
}
+ return new double[][] { lhsDist.getArrayRef(), rhsDist.getArrayRef() };
+ }
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly) {
- AttributeSplitSuggestion bestSuggestion = null;
- for (GreenwaldKhannaQuantileSummary qs : this.attValDistPerClass) {
- if (qs != null) {
- double[] cutpoints = qs.getSuggestedCutpoints();
- for (double cutpoint : cutpoints) {
- double[][] postSplitDists = getClassDistsResultingFromBinarySplit(cutpoint);
- double merit = criterion.getMeritOfSplit(preSplitDist,
- postSplitDists);
- if ((bestSuggestion == null)
- || (merit > bestSuggestion.merit)) {
- bestSuggestion = new AttributeSplitSuggestion(
- new NumericAttributeBinaryTest(attIndex,
- cutpoint, true), postSplitDists, merit);
- }
- }
- }
- }
- return bestSuggestion;
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- // assume all values equal to splitValue go to lhs
- public double[][] getClassDistsResultingFromBinarySplit(double splitValue) {
- DoubleVector lhsDist = new DoubleVector();
- DoubleVector rhsDist = new DoubleVector();
- for (int i = 0; i < this.attValDistPerClass.size(); i++) {
- GreenwaldKhannaQuantileSummary estimator = this.attValDistPerClass.get(i);
- if (estimator != null) {
- long countBelow = estimator.getCountBelow(splitValue);
- lhsDist.addToValue(i, countBelow);
- rhsDist.addToValue(i, estimator.getTotalCount() - countBelow);
- }
- }
- return new double[][]{lhsDist.getArrayRef(), rhsDist.getArrayRef()};
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
-
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
-
- @Override
- public void observeAttributeTarget(double attVal, double target) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
+ @Override
+ public void observeAttributeTarget(double attVal, double target) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java
index d605e84..7b34fa4 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NominalAttributeClassObserver.java
@@ -33,146 +33,146 @@
import com.yahoo.labs.samoa.moa.options.AbstractOptionHandler;
/**
- * Class for observing the class data distribution for a nominal attribute.
- * This observer monitors the class distribution of a given attribute.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * Class for observing the class data distribution for a nominal attribute. This
+ * observer monitors the class distribution of a given attribute. Used in naive
+ * Bayes and decision trees to monitor data statistics on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class NominalAttributeClassObserver extends AbstractOptionHandler implements DiscreteAttributeClassObserver {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected double totalWeightObserved = 0.0;
+ protected double totalWeightObserved = 0.0;
- protected double missingWeightObserved = 0.0;
+ protected double missingWeightObserved = 0.0;
- public AutoExpandVector<DoubleVector> attValDistPerClass = new AutoExpandVector<>();
+ public AutoExpandVector<DoubleVector> attValDistPerClass = new AutoExpandVector<>();
- @Override
- public void observeAttributeClass(double attVal, int classVal, double weight) {
- if (Utils.isMissingValue(attVal)) {
- this.missingWeightObserved += weight;
- } else {
- int attValInt = (int) attVal;
- DoubleVector valDist = this.attValDistPerClass.get(classVal);
- if (valDist == null) {
- valDist = new DoubleVector();
- this.attValDistPerClass.set(classVal, valDist);
- }
- valDist.addToValue(attValInt, weight);
+ @Override
+ public void observeAttributeClass(double attVal, int classVal, double weight) {
+ if (Utils.isMissingValue(attVal)) {
+ this.missingWeightObserved += weight;
+ } else {
+ int attValInt = (int) attVal;
+ DoubleVector valDist = this.attValDistPerClass.get(classVal);
+ if (valDist == null) {
+ valDist = new DoubleVector();
+ this.attValDistPerClass.set(classVal, valDist);
+ }
+ valDist.addToValue(attValInt, weight);
+ }
+ this.totalWeightObserved += weight;
+ }
+
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ DoubleVector obs = this.attValDistPerClass.get(classVal);
+ return obs != null ? (obs.getValue((int) attVal) + 1.0)
+ / (obs.sumOfValues() + obs.numValues()) : 0.0;
+ }
+
+ public double totalWeightOfClassObservations() {
+ return this.totalWeightObserved;
+ }
+
+ public double weightOfObservedMissingValues() {
+ return this.missingWeightObserved;
+ }
+
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly) {
+ AttributeSplitSuggestion bestSuggestion = null;
+ int maxAttValsObserved = getMaxAttValsObserved();
+ if (!binaryOnly) {
+ double[][] postSplitDists = getClassDistsResultingFromMultiwaySplit(maxAttValsObserved);
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NominalAttributeMultiwayTest(attIndex), postSplitDists,
+ merit);
+ }
+ for (int valIndex = 0; valIndex < maxAttValsObserved; valIndex++) {
+ double[][] postSplitDists = getClassDistsResultingFromBinarySplit(valIndex);
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) {
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NominalAttributeBinaryTest(attIndex, valIndex),
+ postSplitDists, merit);
+ }
+ }
+ return bestSuggestion;
+ }
+
+ public int getMaxAttValsObserved() {
+ int maxAttValsObserved = 0;
+ for (DoubleVector attValDist : this.attValDistPerClass) {
+ if ((attValDist != null)
+ && (attValDist.numValues() > maxAttValsObserved)) {
+ maxAttValsObserved = attValDist.numValues();
+ }
+ }
+ return maxAttValsObserved;
+ }
+
+ public double[][] getClassDistsResultingFromMultiwaySplit(
+ int maxAttValsObserved) {
+ DoubleVector[] resultingDists = new DoubleVector[maxAttValsObserved];
+ for (int i = 0; i < resultingDists.length; i++) {
+ resultingDists[i] = new DoubleVector();
+ }
+ for (int i = 0; i < this.attValDistPerClass.size(); i++) {
+ DoubleVector attValDist = this.attValDistPerClass.get(i);
+ if (attValDist != null) {
+ for (int j = 0; j < attValDist.numValues(); j++) {
+ resultingDists[j].addToValue(i, attValDist.getValue(j));
}
- this.totalWeightObserved += weight;
+ }
}
-
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- DoubleVector obs = this.attValDistPerClass.get(classVal);
- return obs != null ? (obs.getValue((int) attVal) + 1.0)
- / (obs.sumOfValues() + obs.numValues()) : 0.0;
+ double[][] distributions = new double[maxAttValsObserved][];
+ for (int i = 0; i < distributions.length; i++) {
+ distributions[i] = resultingDists[i].getArrayRef();
}
+ return distributions;
+ }
- public double totalWeightOfClassObservations() {
- return this.totalWeightObserved;
- }
-
- public double weightOfObservedMissingValues() {
- return this.missingWeightObserved;
- }
-
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly) {
- AttributeSplitSuggestion bestSuggestion = null;
- int maxAttValsObserved = getMaxAttValsObserved();
- if (!binaryOnly) {
- double[][] postSplitDists = getClassDistsResultingFromMultiwaySplit(maxAttValsObserved);
- double merit = criterion.getMeritOfSplit(preSplitDist,
- postSplitDists);
- bestSuggestion = new AttributeSplitSuggestion(
- new NominalAttributeMultiwayTest(attIndex), postSplitDists,
- merit);
+ public double[][] getClassDistsResultingFromBinarySplit(int valIndex) {
+ DoubleVector equalsDist = new DoubleVector();
+ DoubleVector notEqualDist = new DoubleVector();
+ for (int i = 0; i < this.attValDistPerClass.size(); i++) {
+ DoubleVector attValDist = this.attValDistPerClass.get(i);
+ if (attValDist != null) {
+ for (int j = 0; j < attValDist.numValues(); j++) {
+ if (j == valIndex) {
+ equalsDist.addToValue(i, attValDist.getValue(j));
+ } else {
+ notEqualDist.addToValue(i, attValDist.getValue(j));
+ }
}
- for (int valIndex = 0; valIndex < maxAttValsObserved; valIndex++) {
- double[][] postSplitDists = getClassDistsResultingFromBinarySplit(valIndex);
- double merit = criterion.getMeritOfSplit(preSplitDist,
- postSplitDists);
- if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) {
- bestSuggestion = new AttributeSplitSuggestion(
- new NominalAttributeBinaryTest(attIndex, valIndex),
- postSplitDists, merit);
- }
- }
- return bestSuggestion;
+ }
}
+ return new double[][] { equalsDist.getArrayRef(),
+ notEqualDist.getArrayRef() };
+ }
- public int getMaxAttValsObserved() {
- int maxAttValsObserved = 0;
- for (DoubleVector attValDist : this.attValDistPerClass) {
- if ((attValDist != null)
- && (attValDist.numValues() > maxAttValsObserved)) {
- maxAttValsObserved = attValDist.numValues();
- }
- }
- return maxAttValsObserved;
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- public double[][] getClassDistsResultingFromMultiwaySplit(
- int maxAttValsObserved) {
- DoubleVector[] resultingDists = new DoubleVector[maxAttValsObserved];
- for (int i = 0; i < resultingDists.length; i++) {
- resultingDists[i] = new DoubleVector();
- }
- for (int i = 0; i < this.attValDistPerClass.size(); i++) {
- DoubleVector attValDist = this.attValDistPerClass.get(i);
- if (attValDist != null) {
- for (int j = 0; j < attValDist.numValues(); j++) {
- resultingDists[j].addToValue(i, attValDist.getValue(j));
- }
- }
- }
- double[][] distributions = new double[maxAttValsObserved][];
- for (int i = 0; i < distributions.length; i++) {
- distributions[i] = resultingDists[i].getArrayRef();
- }
- return distributions;
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
- public double[][] getClassDistsResultingFromBinarySplit(int valIndex) {
- DoubleVector equalsDist = new DoubleVector();
- DoubleVector notEqualDist = new DoubleVector();
- for (int i = 0; i < this.attValDistPerClass.size(); i++) {
- DoubleVector attValDist = this.attValDistPerClass.get(i);
- if (attValDist != null) {
- for (int j = 0; j < attValDist.numValues(); j++) {
- if (j == valIndex) {
- equalsDist.addToValue(i, attValDist.getValue(j));
- } else {
- notEqualDist.addToValue(i, attValDist.getValue(j));
- }
- }
- }
- }
- return new double[][]{equalsDist.getArrayRef(),
- notEqualDist.getArrayRef()};
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
-
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
-
- @Override
- public void observeAttributeTarget(double attVal, double target) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
+ @Override
+ public void observeAttributeTarget(double attVal, double target) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NullAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NullAttributeClassObserver.java
index def0666..ab14d97 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NullAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NullAttributeClassObserver.java
@@ -27,54 +27,54 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Class for observing the class data distribution for a null attribute.
- * This method is used to disable the observation for an attribute.
- * Used in decision trees to monitor data statistics on leaves.
- *
+ * Class for observing the class data distribution for a null attribute. This
+ * method is used to disable the observation for an attribute. Used in decision
+ * trees to monitor data statistics on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class NullAttributeClassObserver extends AbstractOptionHandler implements AttributeClassObserver {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- @Override
- public void observeAttributeClass(double attVal, int classVal, double weight) {
- }
+ @Override
+ public void observeAttributeClass(double attVal, int classVal, double weight) {
+ }
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- return 0.0;
- }
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ return 0.0;
+ }
- public double totalWeightOfClassObservations() {
- return 0.0;
- }
+ public double totalWeightOfClassObservations() {
+ return 0.0;
+ }
- public double weightOfObservedMissingValues() {
- return 0.0;
- }
+ public double weightOfObservedMissingValues() {
+ return 0.0;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly) {
- return null;
- }
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly) {
+ return null;
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
- @Override
- public void observeAttributeTarget(double attVal, double target) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
+ @Override
+ public void observeAttributeTarget(double attVal, double target) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NumericAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NumericAttributeClassObserver.java
index 1660d5f..ca7a50a 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NumericAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/NumericAttributeClassObserver.java
@@ -22,13 +22,12 @@
/**
* Interface for observing the class data distribution for a numeric attribute.
- * This observer monitors the class distribution of a given attribute.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * This observer monitors the class distribution of a given attribute. Used in
+ * naive Bayes and decision trees to monitor data statistics on leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface NumericAttributeClassObserver extends AttributeClassObserver {
-
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/VFMLNumericAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/VFMLNumericAttributeClassObserver.java
index c7f3d7b..9650532 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/VFMLNumericAttributeClassObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/VFMLNumericAttributeClassObserver.java
@@ -36,188 +36,188 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Class for observing the class data distribution for a numeric attribute as in VFML.
- * Used in naive Bayes and decision trees to monitor data statistics on leaves.
- *
+ * Class for observing the class data distribution for a numeric attribute as in
+ * VFML. Used in naive Bayes and decision trees to monitor data statistics on
+ * leaves.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class VFMLNumericAttributeClassObserver extends AbstractOptionHandler implements NumericAttributeClassObserver {
+ private static final long serialVersionUID = 1L;
+
+ @Override
+ public void observeAttributeTarget(double attVal, double target) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
+
+ protected class Bin implements Serializable {
+
private static final long serialVersionUID = 1L;
- @Override
- public void observeAttributeTarget(double attVal, double target) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
+ public double lowerBound, upperBound;
- protected class Bin implements Serializable {
+ public DoubleVector classWeights = new DoubleVector();
- private static final long serialVersionUID = 1L;
+ public int boundaryClass;
- public double lowerBound, upperBound;
+ public double boundaryWeight;
+ }
- public DoubleVector classWeights = new DoubleVector();
+ protected List<Bin> binList = new ArrayList<>();
- public int boundaryClass;
+ public IntOption numBinsOption = new IntOption("numBins", 'n',
+ "The number of bins.", 10, 1, Integer.MAX_VALUE);
- public double boundaryWeight;
- }
-
- protected List<Bin> binList = new ArrayList<>();
-
- public IntOption numBinsOption = new IntOption("numBins", 'n',
- "The number of bins.", 10, 1, Integer.MAX_VALUE);
-
-
- @Override
- public void observeAttributeClass(double attVal, int classVal, double weight) {
- if (!Utils.isMissingValue(attVal)) {
- if (this.binList.size() < 1) {
- // create the first bin
- Bin newBin = new Bin();
- newBin.classWeights.addToValue(classVal, weight);
- newBin.boundaryClass = classVal;
- newBin.boundaryWeight = weight;
- newBin.upperBound = attVal;
- newBin.lowerBound = attVal;
- this.binList.add(newBin);
- } else {
- // find bin containing new example with binary search
- int index = 0;
- boolean found = false;
- int min = 0;
- int max = this.binList.size() - 1;
- while ((min <= max) && !found) {
- int i = (min + max) / 2;
- Bin bin = this.binList.get(i);
- if (((attVal >= bin.lowerBound) && (attVal < bin.upperBound))
- || ((i == this.binList.size() - 1)
- && (attVal >= bin.lowerBound) && (attVal <= bin.upperBound))) {
- found = true;
- index = i;
- } else if (attVal < bin.lowerBound) {
- max = i - 1;
- } else {
- min = i + 1;
- }
- }
- boolean first = false;
- boolean last = false;
- if (!found) {
- // determine if it is before or after the existing range
- Bin bin = this.binList.get(0);
- if (bin.lowerBound > attVal) {
- // go before the first bin
- index = 0;
- first = true;
- } else {
- // if we haven't found it yet value must be > last bins
- // upperBound
- index = this.binList.size() - 1;
- last = true;
- }
- }
- Bin bin = this.binList.get(index); // VLIndex(ct->bins, index);
- if ((bin.lowerBound == attVal)
- || (this.binList.size() >= this.numBinsOption.getValue())) {// Option.getValue())
- // {//1000)
- // {
- // if this is the exact same boundary and class as the bin
- // boundary or we aren't adding new bins any more then
- // increment
- // boundary counts
- bin.classWeights.addToValue(classVal, weight);
- if ((bin.boundaryClass == classVal)
- && (bin.lowerBound == attVal)) {
- // if it is also the same class then special case it
- bin.boundaryWeight += weight;
- }
- } else {
- // create a new bin
- Bin newBin = new Bin();
- newBin.classWeights.addToValue(classVal, weight);
- newBin.boundaryWeight = weight;
- newBin.boundaryClass = classVal;
- newBin.upperBound = bin.upperBound;
- newBin.lowerBound = attVal;
-
- double percent = 0.0;
- // estimate initial counts with a linear interpolation
- if (!((bin.upperBound - bin.lowerBound == 0) || last || first)) {
- percent = 1.0 - ((attVal - bin.lowerBound) / (bin.upperBound - bin.lowerBound));
- }
-
- // take out the boundry points, they stay with the old bin
- bin.classWeights.addToValue(bin.boundaryClass,
- -bin.boundaryWeight);
- DoubleVector weightToShift = new DoubleVector(
- bin.classWeights);
- weightToShift.scaleValues(percent);
- newBin.classWeights.addValues(weightToShift);
- bin.classWeights.subtractValues(weightToShift);
- // put the boundry examples back in
- bin.classWeights.addToValue(bin.boundaryClass,
- bin.boundaryWeight);
-
- // insert the new bin in the right place
- if (last) {
- bin.upperBound = attVal;
- newBin.upperBound = attVal;
- this.binList.add(newBin);
- } else if (first) {
- newBin.upperBound = bin.lowerBound;
- this.binList.add(0, newBin);
- } else {
- newBin.upperBound = bin.upperBound;
- bin.upperBound = attVal;
- this.binList.add(index + 1, newBin);
- }
- }
- }
+ @Override
+ public void observeAttributeClass(double attVal, int classVal, double weight) {
+ if (!Utils.isMissingValue(attVal)) {
+ if (this.binList.size() < 1) {
+ // create the first bin
+ Bin newBin = new Bin();
+ newBin.classWeights.addToValue(classVal, weight);
+ newBin.boundaryClass = classVal;
+ newBin.boundaryWeight = weight;
+ newBin.upperBound = attVal;
+ newBin.lowerBound = attVal;
+ this.binList.add(newBin);
+ } else {
+ // find bin containing new example with binary search
+ int index = 0;
+ boolean found = false;
+ int min = 0;
+ int max = this.binList.size() - 1;
+ while ((min <= max) && !found) {
+ int i = (min + max) / 2;
+ Bin bin = this.binList.get(i);
+ if (((attVal >= bin.lowerBound) && (attVal < bin.upperBound))
+ || ((i == this.binList.size() - 1)
+ && (attVal >= bin.lowerBound) && (attVal <= bin.upperBound))) {
+ found = true;
+ index = i;
+ } else if (attVal < bin.lowerBound) {
+ max = i - 1;
+ } else {
+ min = i + 1;
+ }
}
- }
-
- @Override
- public double probabilityOfAttributeValueGivenClass(double attVal,
- int classVal) {
- // TODO: NaiveBayes broken until implemented
- return 0.0;
- }
-
- @Override
- public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
- SplitCriterion criterion, double[] preSplitDist, int attIndex,
- boolean binaryOnly) {
- AttributeSplitSuggestion bestSuggestion = null;
- DoubleVector rightDist = new DoubleVector();
- for (Bin bin : this.binList) {
- rightDist.addValues(bin.classWeights);
+ boolean first = false;
+ boolean last = false;
+ if (!found) {
+ // determine if it is before or after the existing range
+ Bin bin = this.binList.get(0);
+ if (bin.lowerBound > attVal) {
+ // go before the first bin
+ index = 0;
+ first = true;
+ } else {
+ // if we haven't found it yet value must be > last bins
+ // upperBound
+ index = this.binList.size() - 1;
+ last = true;
+ }
}
- DoubleVector leftDist = new DoubleVector();
- for (Bin bin : this.binList) {
- leftDist.addValues(bin.classWeights);
- rightDist.subtractValues(bin.classWeights);
- double[][] postSplitDists = new double[][]{
- leftDist.getArrayCopy(), rightDist.getArrayCopy()};
- double merit = criterion.getMeritOfSplit(preSplitDist,
- postSplitDists);
- if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) {
- bestSuggestion = new AttributeSplitSuggestion(
- new NumericAttributeBinaryTest(attIndex,
- bin.upperBound, false), postSplitDists, merit);
- }
+ Bin bin = this.binList.get(index); // VLIndex(ct->bins, index);
+ if ((bin.lowerBound == attVal)
+ || (this.binList.size() >= this.numBinsOption.getValue())) {// Option.getValue())
+ // {//1000)
+ // {
+ // if this is the exact same boundary and class as the bin
+ // boundary or we aren't adding new bins any more then
+ // increment
+ // boundary counts
+ bin.classWeights.addToValue(classVal, weight);
+ if ((bin.boundaryClass == classVal)
+ && (bin.lowerBound == attVal)) {
+ // if it is also the same class then special case it
+ bin.boundaryWeight += weight;
+ }
+ } else {
+ // create a new bin
+ Bin newBin = new Bin();
+ newBin.classWeights.addToValue(classVal, weight);
+ newBin.boundaryWeight = weight;
+ newBin.boundaryClass = classVal;
+ newBin.upperBound = bin.upperBound;
+ newBin.lowerBound = attVal;
+
+ double percent = 0.0;
+ // estimate initial counts with a linear interpolation
+ if (!((bin.upperBound - bin.lowerBound == 0) || last || first)) {
+ percent = 1.0 - ((attVal - bin.lowerBound) / (bin.upperBound - bin.lowerBound));
+ }
+
+ // take out the boundary points, they stay with the old bin
+ bin.classWeights.addToValue(bin.boundaryClass,
+ -bin.boundaryWeight);
+ DoubleVector weightToShift = new DoubleVector(
+ bin.classWeights);
+ weightToShift.scaleValues(percent);
+ newBin.classWeights.addValues(weightToShift);
+ bin.classWeights.subtractValues(weightToShift);
+ // put the boundary examples back in
+ bin.classWeights.addToValue(bin.boundaryClass,
+ bin.boundaryWeight);
+
+ // insert the new bin in the right place
+ if (last) {
+ bin.upperBound = attVal;
+ newBin.upperBound = attVal;
+ this.binList.add(newBin);
+ } else if (first) {
+ newBin.upperBound = bin.lowerBound;
+ this.binList.add(0, newBin);
+ } else {
+ newBin.upperBound = bin.upperBound;
+ bin.upperBound = attVal;
+ this.binList.add(index + 1, newBin);
+ }
}
- return bestSuggestion;
+ }
}
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public double probabilityOfAttributeValueGivenClass(double attVal,
+ int classVal) {
+ // TODO: NaiveBayes broken until implemented
+ return 0.0;
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- // TODO Auto-generated method stub
+ @Override
+ public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
+ SplitCriterion criterion, double[] preSplitDist, int attIndex,
+ boolean binaryOnly) {
+ AttributeSplitSuggestion bestSuggestion = null;
+ DoubleVector rightDist = new DoubleVector();
+ for (Bin bin : this.binList) {
+ rightDist.addValues(bin.classWeights);
}
+ DoubleVector leftDist = new DoubleVector();
+ for (Bin bin : this.binList) {
+ leftDist.addValues(bin.classWeights);
+ rightDist.subtractValues(bin.classWeights);
+ double[][] postSplitDists = new double[][] {
+ leftDist.getArrayCopy(), rightDist.getArrayCopy() };
+ double merit = criterion.getMeritOfSplit(preSplitDist,
+ postSplitDists);
+ if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) {
+ bestSuggestion = new AttributeSplitSuggestion(
+ new NumericAttributeBinaryTest(attIndex,
+ bin.upperBound, false), postSplitDists, merit);
+ }
+ }
+ return bestSuggestion;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
+
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalBinaryTest.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalBinaryTest.java
index d1267a7..b273c33 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalBinaryTest.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalBinaryTest.java
@@ -21,15 +21,16 @@
*/
/**
- * Abstract binary conditional test for instances to use to split nodes in Hoeffding trees.
- *
+ * Abstract binary conditional test for instances to use to split nodes in
+ * Hoeffding trees.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public abstract class InstanceConditionalBinaryTest extends InstanceConditionalTest {
- @Override
- public int maxBranches() {
- return 2;
- }
+ @Override
+ public int maxBranches() {
+ return 2;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalTest.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalTest.java
index 4d1b955..b893e06 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalTest.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/InstanceConditionalTest.java
@@ -25,52 +25,57 @@
import com.yahoo.labs.samoa.instances.Instance;
/**
- * Abstract conditional test for instances to use to split nodes in Hoeffding trees.
- *
+ * Abstract conditional test for instances to use to split nodes in Hoeffding
+ * trees.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public abstract class InstanceConditionalTest extends AbstractMOAObject {
- /**
- * Returns the number of the branch for an instance, -1 if unknown.
- *
- * @param inst the instance to be used
- * @return the number of the branch for an instance, -1 if unknown.
- */
- public abstract int branchForInstance(Instance inst);
+ /**
+ * Returns the number of the branch for an instance, -1 if unknown.
+ *
+ * @param inst
+ * the instance to be used
+ * @return the number of the branch for an instance, -1 if unknown.
+ */
+ public abstract int branchForInstance(Instance inst);
- /**
- * Gets whether the number of the branch for an instance is known.
- *
- * @param inst
- * @return true if the number of the branch for an instance is known
- */
- public boolean resultKnownForInstance(Instance inst) {
- return branchForInstance(inst) >= 0;
- }
+ /**
+ * Gets whether the number of the branch for an instance is known.
+ *
+ * @param inst
+ * @return true if the number of the branch for an instance is known
+ */
+ public boolean resultKnownForInstance(Instance inst) {
+ return branchForInstance(inst) >= 0;
+ }
- /**
- * Gets the number of maximum branches, -1 if unknown.
- *
- * @return the number of maximum branches, -1 if unknown..
- */
- public abstract int maxBranches();
+ /**
+ * Gets the number of maximum branches, -1 if unknown.
+ *
+ * @return the number of maximum branches, -1 if unknown.
+ */
+ public abstract int maxBranches();
- /**
- * Gets the text that describes the condition of a branch. It is used to describe the branch.
- *
- * @param branch the number of the branch to describe
- * @param context the context or header of the data stream
- * @return the text that describes the condition of the branch
- */
- public abstract String describeConditionForBranch(int branch,
- InstancesHeader context);
+ /**
+ * Gets the text that describes the condition of a branch. It is used to
+ * describe the branch.
+ *
+ * @param branch
+ * the number of the branch to describe
+ * @param context
+ * the context or header of the data stream
+ * @return the text that describes the condition of the branch
+ */
+ public abstract String describeConditionForBranch(int branch,
+ InstancesHeader context);
- /**
- * Returns an array with the attributes that the test depends on.
- *
- * @return an array with the attributes that the test depends on
- */
- public abstract int[] getAttsTestDependsOn();
+ /**
+ * Returns an array with the attributes that the test depends on.
+ *
+ * @return an array with the attributes that the test depends on
+ */
+ public abstract int[] getAttsTestDependsOn();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java
index da3c717..5056737 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeBinaryTest.java
@@ -24,50 +24,51 @@
import com.yahoo.labs.samoa.instances.Instance;
/**
- * Nominal binary conditional test for instances to use to split nodes in Hoeffding trees.
- *
+ * Nominal binary conditional test for instances to use to split nodes in
+ * Hoeffding trees.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class NominalAttributeBinaryTest extends InstanceConditionalBinaryTest {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected int attIndex;
+ protected int attIndex;
- protected int attValue;
+ protected int attValue;
- public NominalAttributeBinaryTest(int attIndex, int attValue) {
- this.attIndex = attIndex;
- this.attValue = attValue;
+ public NominalAttributeBinaryTest(int attIndex, int attValue) {
+ this.attIndex = attIndex;
+ this.attValue = attValue;
+ }
+
+ @Override
+ public int branchForInstance(Instance inst) {
+ int instAttIndex = this.attIndex < inst.classIndex() ? this.attIndex
+ : this.attIndex + 1;
+ return inst.isMissing(instAttIndex) ? -1 : ((int) inst.value(instAttIndex) == this.attValue ? 0 : 1);
+ }
+
+ @Override
+ public String describeConditionForBranch(int branch, InstancesHeader context) {
+ if ((branch == 0) || (branch == 1)) {
+ return InstancesHeader.getAttributeNameString(context,
+ this.attIndex)
+ + (branch == 0 ? " = " : " != ")
+ + InstancesHeader.getNominalValueString(context,
+ this.attIndex, this.attValue);
}
+ throw new IndexOutOfBoundsException();
+ }
- @Override
- public int branchForInstance(Instance inst) {
- int instAttIndex = this.attIndex < inst.classIndex() ? this.attIndex
- : this.attIndex + 1;
- return inst.isMissing(instAttIndex) ? -1 : ((int) inst.value(instAttIndex) == this.attValue ? 0 : 1);
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- public String describeConditionForBranch(int branch, InstancesHeader context) {
- if ((branch == 0) || (branch == 1)) {
- return InstancesHeader.getAttributeNameString(context,
- this.attIndex)
- + (branch == 0 ? " = " : " != ")
- + InstancesHeader.getNominalValueString(context,
- this.attIndex, this.attValue);
- }
- throw new IndexOutOfBoundsException();
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
-
- @Override
- public int[] getAttsTestDependsOn() {
- return new int[]{this.attIndex};
- }
+ @Override
+ public int[] getAttsTestDependsOn() {
+ return new int[] { this.attIndex };
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeMultiwayTest.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeMultiwayTest.java
index 82c91d3..5c64070 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeMultiwayTest.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NominalAttributeMultiwayTest.java
@@ -24,48 +24,49 @@
import com.yahoo.labs.samoa.instances.Instance;
/**
- * Nominal multi way conditional test for instances to use to split nodes in Hoeffding trees.
- *
+ * Nominal multi way conditional test for instances to use to split nodes in
+ * Hoeffding trees.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class NominalAttributeMultiwayTest extends InstanceConditionalTest {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected int attIndex;
+ protected int attIndex;
- public NominalAttributeMultiwayTest(int attIndex) {
- this.attIndex = attIndex;
- }
+ public NominalAttributeMultiwayTest(int attIndex) {
+ this.attIndex = attIndex;
+ }
- @Override
- public int branchForInstance(Instance inst) {
- int instAttIndex = this.attIndex ; //< inst.classIndex() ? this.attIndex
- //: this.attIndex + 1;
- return inst.isMissing(instAttIndex) ? -1 : (int) inst.value(instAttIndex);
- }
+ @Override
+ public int branchForInstance(Instance inst) {
+ int instAttIndex = this.attIndex; // < inst.classIndex() ? this.attIndex
+ // : this.attIndex + 1;
+ return inst.isMissing(instAttIndex) ? -1 : (int) inst.value(instAttIndex);
+ }
- @Override
- public String describeConditionForBranch(int branch, InstancesHeader context) {
- return InstancesHeader.getAttributeNameString(context, this.attIndex)
- + " = "
- + InstancesHeader.getNominalValueString(context, this.attIndex,
- branch);
- }
+ @Override
+ public String describeConditionForBranch(int branch, InstancesHeader context) {
+ return InstancesHeader.getAttributeNameString(context, this.attIndex)
+ + " = "
+ + InstancesHeader.getNominalValueString(context, this.attIndex,
+ branch);
+ }
- @Override
- public int maxBranches() {
- return -1;
- }
+ @Override
+ public int maxBranches() {
+ return -1;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- public int[] getAttsTestDependsOn() {
- return new int[]{this.attIndex};
- }
+ @Override
+ public int[] getAttsTestDependsOn() {
+ return new int[] { this.attIndex };
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java
index 82c7395..0a05742 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/conditionaltests/NumericAttributeBinaryTest.java
@@ -24,69 +24,70 @@
import com.yahoo.labs.samoa.instances.Instance;
/**
- * Numeric binary conditional test for instances to use to split nodes in Hoeffding trees.
- *
+ * Numeric binary conditional test for instances to use to split nodes in
+ * Hoeffding trees.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class NumericAttributeBinaryTest extends InstanceConditionalBinaryTest {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected int attIndex;
+ protected int attIndex;
- protected double attValue;
+ protected double attValue;
- protected boolean equalsPassesTest;
+ protected boolean equalsPassesTest;
- public NumericAttributeBinaryTest(int attIndex, double attValue,
- boolean equalsPassesTest) {
- this.attIndex = attIndex;
- this.attValue = attValue;
- this.equalsPassesTest = equalsPassesTest;
+ public NumericAttributeBinaryTest(int attIndex, double attValue,
+ boolean equalsPassesTest) {
+ this.attIndex = attIndex;
+ this.attValue = attValue;
+ this.equalsPassesTest = equalsPassesTest;
+ }
+
+ @Override
+ public int branchForInstance(Instance inst) {
+ int instAttIndex = this.attIndex; // < inst.classIndex() ? this.attIndex
+ // : this.attIndex + 1;
+ if (inst.isMissing(instAttIndex)) {
+ return -1;
}
-
- @Override
- public int branchForInstance(Instance inst) {
- int instAttIndex = this.attIndex ; // < inst.classIndex() ? this.attIndex
- // : this.attIndex + 1;
- if (inst.isMissing(instAttIndex)) {
- return -1;
- }
- double v = inst.value(instAttIndex);
- if (v == this.attValue) {
- return this.equalsPassesTest ? 0 : 1;
- }
- return v < this.attValue ? 0 : 1;
+ double v = inst.value(instAttIndex);
+ if (v == this.attValue) {
+ return this.equalsPassesTest ? 0 : 1;
}
+ return v < this.attValue ? 0 : 1;
+ }
- @Override
- public String describeConditionForBranch(int branch, InstancesHeader context) {
- if ((branch == 0) || (branch == 1)) {
- char compareChar = branch == 0 ? '<' : '>';
- int equalsBranch = this.equalsPassesTest ? 0 : 1;
- return InstancesHeader.getAttributeNameString(context,
- this.attIndex)
- + ' '
- + compareChar
- + (branch == equalsBranch ? "= " : " ")
- + InstancesHeader.getNumericValueString(context,
- this.attIndex, this.attValue);
- }
- throw new IndexOutOfBoundsException();
+ @Override
+ public String describeConditionForBranch(int branch, InstancesHeader context) {
+ if ((branch == 0) || (branch == 1)) {
+ char compareChar = branch == 0 ? '<' : '>';
+ int equalsBranch = this.equalsPassesTest ? 0 : 1;
+ return InstancesHeader.getAttributeNameString(context,
+ this.attIndex)
+ + ' '
+ + compareChar
+ + (branch == equalsBranch ? "= " : " ")
+ + InstancesHeader.getNumericValueString(context,
+ this.attIndex, this.attValue);
}
+ throw new IndexOutOfBoundsException();
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- public int[] getAttsTestDependsOn() {
- return new int[]{this.attIndex};
- }
+ @Override
+ public int[] getAttsTestDependsOn() {
+ return new int[] { this.attIndex };
+ }
- public double getSplitValue() {
- return this.attValue;
- }
+ public double getSplitValue() {
+ return this.attValue;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWIN.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWIN.java
index b05a302..efacc93 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWIN.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWIN.java
@@ -23,582 +23,573 @@
import com.yahoo.labs.samoa.moa.AbstractMOAObject;
/**
- * ADaptive sliding WINdow method. This method is a change detector and estimator.
- * It keeps a variable-length window of recently seen
- * items, with the property that the window has the maximal length statistically
- * consistent with the hypothesis "there has been no change in the average value
- * inside the window".
- *
- *
+ * ADaptive sliding WINdow method. This method is a change detector and
+ * estimator. It keeps a variable-length window of recently seen items, with the
+ * property that the window has the maximal length statistically consistent with
+ * the hypothesis "there has been no change in the average value inside the
+ * window".
+ *
+ *
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
* @version $Revision: 7 $
*/
public class ADWIN extends AbstractMOAObject {
- private class List extends AbstractMOAObject {
+ private class List extends AbstractMOAObject {
- protected int count;
+ protected int count;
- protected ListItem head;
+ protected ListItem head;
- protected ListItem tail;
+ protected ListItem tail;
- public List() {
-// post: initializes the list to be empty.
- clear();
- addToHead();
- }
-
- /* Interface Store Methods */
- public int size() {
- // post: returns the number of elements in the list.
- return this.count;
- }
-
- public ListItem head() {
- // post: returns the number of elements in the list.
- return this.head;
- }
-
- public ListItem tail() {
- // post: returns the number of elements in the list.
- return this.tail;
- }
-
- public boolean isEmpty() {
- // post: returns the true iff store is empty.
- return (this.size() == 0);
- }
-
- public void clear() {
- // post: clears the list so that it contains no elements.
- this.head = null;
- this.tail = null;
- this.count = 0;
- }
-
- /* Interface List Methods */
- public void addToHead() {
- // pre: anObject is non-null
- // post: the object is added to the beginning of the list
- this.head = new ListItem(this.head, null);
- if (this.tail == null) {
- this.tail = this.head;
- }
- this.count++;
- }
-
- public void removeFromHead() {
- // pre: list is not empty
- // post: removes and returns first object from the list
-// ListItem temp;
-// temp = this.head;
- this.head = this.head.next();
- if (this.head != null) {
- this.head.setPrevious(null);
- } else {
- this.tail = null;
- }
- this.count--;
- }
-
- public void addToTail() {
-// pre: anObject is non-null
-// post: the object is added at the end of the list
- this.tail = new ListItem(null, this.tail);
- if (this.head == null) {
- this.head = this.tail;
- }
- this.count++;
- }
-
- public void removeFromTail() {
-// pre: list is not empty
-// post: the last object in the list is removed and returned
-// ListItem temp;
-// temp = this.tail;
- this.tail = this.tail.previous();
- if (this.tail == null) {
- this.head = null;
- } else {
- this.tail.setNext(null);
- }
- this.count--;
- //temp=null;
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- }
+ public List() {
+ // post: initializes the list to be empty.
+ clear();
+ addToHead();
}
- private class ListItem extends AbstractMOAObject {
- protected ListItem next;
-
- protected ListItem previous;
-
- protected int bucketSizeRow = 0;
-
- protected int MAXBUCKETS = ADWIN.MAXBUCKETS;
-
- protected double bucketTotal[] = new double[MAXBUCKETS + 1];
-
- protected double bucketVariance[] = new double[MAXBUCKETS + 1];
-
- public ListItem() {
-// post: initializes the node to be a tail node
-// containing the given value.
- this(null, null);
- }
-
- public void clear() {
- bucketSizeRow = 0;
- for (int k = 0; k <= MAXBUCKETS; k++) {
- clearBucket(k);
- }
- }
-
- private void clearBucket(int k) {
- setTotal(0, k);
- setVariance(0, k);
- }
-
- public ListItem(ListItem nextNode, ListItem previousNode) {
-// post: initializes the node to contain the given
-// object and link to the given next node.
- //this.data = element;
- this.next = nextNode;
- this.previous = previousNode;
- if (nextNode != null) {
- nextNode.previous = this;
- }
- if (previousNode != null) {
- previousNode.next = this;
- }
- clear();
- }
-
- public void insertBucket(double Value, double Variance) {
-// insert a Bucket at the end
- int k = bucketSizeRow;
- bucketSizeRow++;
- //Insert new bucket
- setTotal(Value, k);
- setVariance(Variance, k);
- }
-
- public void RemoveBucket() {
-// Removes the first Buvket
- compressBucketsRow(1);
- }
-
- public void compressBucketsRow(int NumberItemsDeleted) {
- //Delete first elements
- for (int k = NumberItemsDeleted; k <= MAXBUCKETS; k++) {
- bucketTotal[k - NumberItemsDeleted] = bucketTotal[k];
- bucketVariance[k - NumberItemsDeleted] = bucketVariance[k];
- }
- for (int k = 1; k <= NumberItemsDeleted; k++) {
- clearBucket(MAXBUCKETS - k + 1);
- }
- bucketSizeRow -= NumberItemsDeleted;
- //BucketNumber-=NumberItemsDeleted;
- }
-
- public ListItem previous() {
-// post: returns the previous node.
- return this.previous;
- }
-
- public void setPrevious(ListItem previous) {
-// post: sets the previous node to be the given node
- this.previous = previous;
- }
-
- public ListItem next() {
-// post: returns the next node.
- return this.next;
- }
-
- public void setNext(ListItem next) {
-// post: sets the next node to be the given node
- this.next = next;
- }
-
- public double Total(int k) {
-// post: returns the element in this node
- return bucketTotal[k];
- }
-
- public double Variance(int k) {
-// post: returns the element in this node
- return bucketVariance[k];
- }
-
- public void setTotal(double value, int k) {
-// post: sets the element in this node to the given
-// object.
- bucketTotal[k] = value;
- }
-
- public void setVariance(double value, int k) {
-// post: sets the element in this node to the given
-// object.
- bucketVariance[k] = value;
- }
- /*
- public ListItem(Object element,
- ListItem nextNode){
- // post: initializes the node to contain the given
- // object and link to the given next node.
- this.data = element;
- this.next = nextNode;
- }
- public ListItem(Object element) {
- // post: initializes the node to be a tail node
- // containing the given value.
- this(element, null);
- }
-
-
- public Object value() {
- // post: returns the element in this node
- return this.data;
- }
- public void setValue(Object anObject) {
- // post: sets the element in this node to the given
- // object.
- this.data = anObject;
- }
- */
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- }
+ /* Interface Store Methods */
+ public int size() {
+ // post: returns the number of elements in the list.
+ return this.count;
}
- public static final double DELTA = .002; //.1;
-
- private static final int mintMinimLongitudWindow = 10; //10
-
- private double mdbldelta = .002; //.1;
-
- private int mintTime = 0;
-
- private int mintClock = 32;
-
- private double mdblWidth = 0; // Mean of Width = mdblWidth/Number of items
- //BUCKET
-
- public static final int MAXBUCKETS = 5;
-
- private int lastBucketRow = 0;
-
- private double TOTAL = 0;
-
- private double VARIANCE = 0;
-
- private int WIDTH = 0;
-
- private int BucketNumber = 0;
-
- private int Detect = 0;
-
- private int numberDetections = 0;
-
- private int DetectTwice = 0;
-
- private boolean blnBucketDeleted = false;
-
- private int BucketNumberMAX = 0;
-
- private int mintMinWinLength = 5;
-
- private List listRowBuckets;
-
- public boolean getChange() {
- return blnBucketDeleted;
+ public ListItem head() {
+ // post: returns the head item of the list.
+ return this.head;
}
- public void resetChange() {
- blnBucketDeleted = false;
+ public ListItem tail() {
+ // post: returns the tail item of the list.
+ return this.tail;
}
- public int getBucketsUsed() {
- return BucketNumberMAX;
+ public boolean isEmpty() {
+ // post: returns the true iff store is empty.
+ return (this.size() == 0);
}
- public int getWidth() {
- return WIDTH;
+ public void clear() {
+ // post: clears the list so that it contains no elements.
+ this.head = null;
+ this.tail = null;
+ this.count = 0;
}
- public void setClock(int intClock) {
- mintClock = intClock;
+ /* Interface List Methods */
+ public void addToHead() {
+ // pre: anObject is non-null
+ // post: the object is added to the beginning of the list
+ this.head = new ListItem(this.head, null);
+ if (this.tail == null) {
+ this.tail = this.head;
+ }
+ this.count++;
}
- public int getClock() {
- return mintClock;
+ public void removeFromHead() {
+ // pre: list is not empty
+ // post: removes and returns first object from the list
+ // ListItem temp;
+ // temp = this.head;
+ this.head = this.head.next();
+ if (this.head != null) {
+ this.head.setPrevious(null);
+ } else {
+ this.tail = null;
+ }
+ this.count--;
}
- public boolean getWarning() {
- return false;
+ public void addToTail() {
+ // pre: anObject is non-null
+ // post: the object is added at the end of the list
+ this.tail = new ListItem(null, this.tail);
+ if (this.head == null) {
+ this.head = this.tail;
+ }
+ this.count++;
}
- public boolean getDetect() {
- return (Detect == mintTime);
- }
-
- public int getNumberDetections() {
- return numberDetections;
- }
-
- public double getTotal() {
- return TOTAL;
- }
-
- public double getEstimation() {
- return TOTAL / WIDTH;
- }
-
- public double getVariance() {
- return VARIANCE / WIDTH;
- }
-
- public double getWidthT() {
- return mdblWidth;
- }
-
- private void initBuckets() {
- //Init buckets
- listRowBuckets = new List();
- lastBucketRow = 0;
- TOTAL = 0;
- VARIANCE = 0;
- WIDTH = 0;
- BucketNumber = 0;
- }
-
- private void insertElement(double Value) {
- WIDTH++;
- insertElementBucket(0, Value, listRowBuckets.head());
- double incVariance = 0;
- if (WIDTH > 1) {
- incVariance = (WIDTH - 1) * (Value - TOTAL / (WIDTH - 1)) * (Value - TOTAL / (WIDTH - 1)) / WIDTH;
- }
- VARIANCE += incVariance;
- TOTAL += Value;
- compressBuckets();
- }
-
- private void insertElementBucket(double Variance, double Value, ListItem Node) {
- //Insert new bucket
- Node.insertBucket(Value, Variance);
- BucketNumber++;
- if (BucketNumber > BucketNumberMAX) {
- BucketNumberMAX = BucketNumber;
- }
- }
-
- private int bucketSize(int Row) {
- return (int) Math.pow(2, Row);
- }
-
- public int deleteElement() {
- //LIST
- //Update statistics
- ListItem Node;
- Node = listRowBuckets.tail();
- int n1 = bucketSize(lastBucketRow);
- WIDTH -= n1;
- TOTAL -= Node.Total(0);
- double u1 = Node.Total(0) / n1;
- double incVariance = Node.Variance(0) + n1 * WIDTH * (u1 - TOTAL / WIDTH) * (u1 - TOTAL / WIDTH) / (n1 + WIDTH);
- VARIANCE -= incVariance;
-
- //Delete Bucket
- Node.RemoveBucket();
- BucketNumber--;
- if (Node.bucketSizeRow == 0) {
- listRowBuckets.removeFromTail();
- lastBucketRow--;
- }
- return n1;
- }
-
- public void compressBuckets() {
- //Traverse the list of buckets in increasing order
- int n1, n2;
- double u2, u1, incVariance;
- ListItem cursor;
- ListItem nextNode;
- cursor = listRowBuckets.head();
- int i = 0;
- do {
- //Find the number of buckets in a row
- int k = cursor.bucketSizeRow;
- //If the row is full, merge buckets
- if (k == MAXBUCKETS + 1) {
- nextNode = cursor.next();
- if (nextNode == null) {
- listRowBuckets.addToTail();
- nextNode = cursor.next();
- lastBucketRow++;
- }
- n1 = bucketSize(i);
- n2 = bucketSize(i);
- u1 = cursor.Total(0) / n1;
- u2 = cursor.Total(1) / n2;
- incVariance = n1 * n2 * (u1 - u2) * (u1 - u2) / (n1 + n2);
-
- nextNode.insertBucket(cursor.Total(0) + cursor.Total(1), cursor.Variance(0) + cursor.Variance(1) + incVariance);
- BucketNumber++;
- cursor.compressBucketsRow(2);
- if (nextNode.bucketSizeRow <= MAXBUCKETS) {
- break;
- }
- } else {
- break;
- }
- cursor = cursor.next();
- i++;
- } while (cursor != null);
- }
-
- public boolean setInput(double intEntrada) {
- return setInput(intEntrada, mdbldelta);
- }
-
- public boolean setInput(double intEntrada, double delta) {
- boolean blnChange = false;
- boolean blnExit;
- ListItem cursor;
- mintTime++;
-
- //1,2)Increment window in one element
- insertElement(intEntrada);
- blnBucketDeleted = false;
- //3)Reduce window
- if (mintTime % mintClock == 0 && getWidth() > mintMinimLongitudWindow) {
- boolean blnReduceWidth = true; // Diference
-
- while (blnReduceWidth) // Diference
- {
- blnReduceWidth = false; // Diference
- blnExit = false;
- int n0 = 0;
- int n1 = WIDTH;
- double u0 = 0;
- double u1 = getTotal();
- double v0 = 0;
- double v1 = VARIANCE;
- double n2;
- double u2;
-
- cursor = listRowBuckets.tail();
- int i = lastBucketRow;
- do {
- for (int k = 0; k <= (cursor.bucketSizeRow - 1); k++) {
- n2 = bucketSize(i);
- u2 = cursor.Total(k);
- if (n0 > 0) {
- v0 += cursor.Variance(k) + (double) n0 * n2 * (u0 / n0 - u2 / n2) * (u0 / n0 - u2 / n2) / (n0 + n2);
- }
- if (n1 > 0) {
- v1 -= cursor.Variance(k) + (double) n1 * n2 * (u1 / n1 - u2 / n2) * (u1 / n1 - u2 / n2) / (n1 + n2);
- }
-
- n0 += bucketSize(i);
- n1 -= bucketSize(i);
- u0 += cursor.Total(k);
- u1 -= cursor.Total(k);
-
- if (i == 0 && k == cursor.bucketSizeRow - 1) {
- blnExit = true;
- break;
- }
- double absvalue = (u0 / n0) - (u1 / n1); //n1<WIDTH-mintMinWinLength-1
- if ((n1 > mintMinWinLength + 1 && n0 > mintMinWinLength + 1) && // Diference NEGATIVE
- //if(
- blnCutexpression(n0, n1, u0, u1, v0, v1, absvalue, delta)) {
- blnBucketDeleted = true;
- Detect = mintTime;
-
- if (Detect == 0) {
- Detect = mintTime;
- //blnFirst=true;
- //blnWarning=true;
- } else if (DetectTwice == 0) {
- DetectTwice = mintTime;
- //blnDetect=true;
- }
- blnReduceWidth = true; // Diference
- blnChange = true;
- if (getWidth() > 0) { //Reduce width of the window
- //while (n0>0) // Diference NEGATIVE
- n0 -= deleteElement();
- blnExit = true;
- break;
- }
- } //End if
- }//Next k
- cursor = cursor.previous();
- i--;
- } while (((!blnExit && cursor != null)));
- }//End While // Diference
- }//End if
-
- mdblWidth += getWidth();
- if (blnChange) {
- numberDetections++;
- }
- return blnChange;
- }
-
- private boolean blnCutexpression(int n0, int n1, double u0, double u1, double v0, double v1, double absvalue, double delta) {
- int n = getWidth();
- double dd = Math.log(2 * Math.log(n) / delta); // -- ull perque el ln n va al numerador.
- // Formula Gener 2008
- double v = getVariance();
- double m = ((double) 1 / ((n0 - mintMinWinLength + 1))) + ((double) 1 / ((n1 - mintMinWinLength + 1)));
- double epsilon = Math.sqrt(2 * m * v * dd) + (double) 2 / 3 * dd * m;
-
- return (Math.abs(absvalue) > epsilon);
- }
-
- public ADWIN() {
- mdbldelta = DELTA;
- initBuckets();
- Detect = 0;
- numberDetections = 0;
- DetectTwice = 0;
-
- }
-
- public ADWIN(double d) {
- mdbldelta = d;
- initBuckets();
- Detect = 0;
- numberDetections = 0;
- DetectTwice = 0;
- }
-
- public ADWIN(int cl) {
- mdbldelta = DELTA;
- initBuckets();
- Detect = 0;
- numberDetections = 0;
- DetectTwice = 0;
- mintClock = cl;
- }
-
- public String getEstimatorInfo() {
- return "ADWIN;;";
- }
-
- public void setW(int W0) {
+ public void removeFromTail() {
+ // pre: list is not empty
+ // post: the last object in the list is removed and returned
+ // ListItem temp;
+ // temp = this.tail;
+ this.tail = this.tail.previous();
+ if (this.tail == null) {
+ this.head = null;
+ } else {
+ this.tail.setNext(null);
+ }
+ this.count--;
+ // temp=null;
}
@Override
public void getDescription(StringBuilder sb, int indent) {
}
+ }
+
+ private class ListItem extends AbstractMOAObject {
+ protected ListItem next;
+
+ protected ListItem previous;
+
+ protected int bucketSizeRow = 0;
+
+ protected int MAXBUCKETS = ADWIN.MAXBUCKETS;
+
+ protected double bucketTotal[] = new double[MAXBUCKETS + 1];
+
+ protected double bucketVariance[] = new double[MAXBUCKETS + 1];
+
+ public ListItem() {
+ // post: initializes the node to be a tail node
+ // containing the given value.
+ this(null, null);
+ }
+
+ public void clear() {
+ bucketSizeRow = 0;
+ for (int k = 0; k <= MAXBUCKETS; k++) {
+ clearBucket(k);
+ }
+ }
+
+ private void clearBucket(int k) {
+ setTotal(0, k);
+ setVariance(0, k);
+ }
+
+ public ListItem(ListItem nextNode, ListItem previousNode) {
+ // post: initializes the node to contain the given
+ // object and link to the given next node.
+ // this.data = element;
+ this.next = nextNode;
+ this.previous = previousNode;
+ if (nextNode != null) {
+ nextNode.previous = this;
+ }
+ if (previousNode != null) {
+ previousNode.next = this;
+ }
+ clear();
+ }
+
+ public void insertBucket(double Value, double Variance) {
+ // insert a Bucket at the end
+ int k = bucketSizeRow;
+ bucketSizeRow++;
+ // Insert new bucket
+ setTotal(Value, k);
+ setVariance(Variance, k);
+ }
+
+ public void RemoveBucket() {
+ // Removes the first bucket
+ compressBucketsRow(1);
+ }
+
+ public void compressBucketsRow(int NumberItemsDeleted) {
+ // Delete first elements
+ for (int k = NumberItemsDeleted; k <= MAXBUCKETS; k++) {
+ bucketTotal[k - NumberItemsDeleted] = bucketTotal[k];
+ bucketVariance[k - NumberItemsDeleted] = bucketVariance[k];
+ }
+ for (int k = 1; k <= NumberItemsDeleted; k++) {
+ clearBucket(MAXBUCKETS - k + 1);
+ }
+ bucketSizeRow -= NumberItemsDeleted;
+ // BucketNumber-=NumberItemsDeleted;
+ }
+
+ public ListItem previous() {
+ // post: returns the previous node.
+ return this.previous;
+ }
+
+ public void setPrevious(ListItem previous) {
+ // post: sets the previous node to be the given node
+ this.previous = previous;
+ }
+
+ public ListItem next() {
+ // post: returns the next node.
+ return this.next;
+ }
+
+ public void setNext(ListItem next) {
+ // post: sets the next node to be the given node
+ this.next = next;
+ }
+
+ public double Total(int k) {
+ // post: returns the element in this node
+ return bucketTotal[k];
+ }
+
+ public double Variance(int k) {
+ // post: returns the element in this node
+ return bucketVariance[k];
+ }
+
+ public void setTotal(double value, int k) {
+ // post: sets the element in this node to the given
+ // object.
+ bucketTotal[k] = value;
+ }
+
+ public void setVariance(double value, int k) {
+ // post: sets the element in this node to the given
+ // object.
+ bucketVariance[k] = value;
+ }
+
+ /*
+ * public ListItem(Object element, ListItem nextNode){ // post: initializes
+ * the node to contain the given // object and link to the given next node.
+ * this.data = element; this.next = nextNode; } public ListItem(Object
+ * element) { // post: initializes the node to be a tail node // containing
+ * the given value. this(element, null); }
+ *
+ *
+ * public Object value() { // post: returns the element in this node return
+ * this.data; } public void setValue(Object anObject) { // post: sets the
+ * element in this node to the given // object. this.data = anObject; }
+ */
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ }
+ }
+
+ public static final double DELTA = .002; // .1;
+
+ private static final int mintMinimLongitudWindow = 10; // 10
+
+ private double mdbldelta = .002; // .1;
+
+ private int mintTime = 0;
+
+ private int mintClock = 32;
+
+ private double mdblWidth = 0; // Mean of Width = mdblWidth/Number of items
+ // BUCKET
+
+ public static final int MAXBUCKETS = 5;
+
+ private int lastBucketRow = 0;
+
+ private double TOTAL = 0;
+
+ private double VARIANCE = 0;
+
+ private int WIDTH = 0;
+
+ private int BucketNumber = 0;
+
+ private int Detect = 0;
+
+ private int numberDetections = 0;
+
+ private int DetectTwice = 0;
+
+ private boolean blnBucketDeleted = false;
+
+ private int BucketNumberMAX = 0;
+
+ private int mintMinWinLength = 5;
+
+ private List listRowBuckets;
+
+ public boolean getChange() {
+ return blnBucketDeleted;
+ }
+
+ public void resetChange() {
+ blnBucketDeleted = false;
+ }
+
+ public int getBucketsUsed() {
+ return BucketNumberMAX;
+ }
+
+ public int getWidth() {
+ return WIDTH;
+ }
+
+ public void setClock(int intClock) {
+ mintClock = intClock;
+ }
+
+ public int getClock() {
+ return mintClock;
+ }
+
+ public boolean getWarning() {
+ return false;
+ }
+
+ public boolean getDetect() {
+ return (Detect == mintTime);
+ }
+
+ public int getNumberDetections() {
+ return numberDetections;
+ }
+
+ public double getTotal() {
+ return TOTAL;
+ }
+
+ public double getEstimation() {
+ return TOTAL / WIDTH;
+ }
+
+ public double getVariance() {
+ return VARIANCE / WIDTH;
+ }
+
+ public double getWidthT() {
+ return mdblWidth;
+ }
+
+ private void initBuckets() {
+ // Init buckets
+ listRowBuckets = new List();
+ lastBucketRow = 0;
+ TOTAL = 0;
+ VARIANCE = 0;
+ WIDTH = 0;
+ BucketNumber = 0;
+ }
+
+ private void insertElement(double Value) {
+ WIDTH++;
+ insertElementBucket(0, Value, listRowBuckets.head());
+ double incVariance = 0;
+ if (WIDTH > 1) {
+ incVariance = (WIDTH - 1) * (Value - TOTAL / (WIDTH - 1)) * (Value - TOTAL / (WIDTH - 1)) / WIDTH;
+ }
+ VARIANCE += incVariance;
+ TOTAL += Value;
+ compressBuckets();
+ }
+
+ private void insertElementBucket(double Variance, double Value, ListItem Node) {
+ // Insert new bucket
+ Node.insertBucket(Value, Variance);
+ BucketNumber++;
+ if (BucketNumber > BucketNumberMAX) {
+ BucketNumberMAX = BucketNumber;
+ }
+ }
+
+ private int bucketSize(int Row) {
+ return (int) Math.pow(2, Row);
+ }
+
+ public int deleteElement() {
+ // LIST
+ // Update statistics
+ ListItem Node;
+ Node = listRowBuckets.tail();
+ int n1 = bucketSize(lastBucketRow);
+ WIDTH -= n1;
+ TOTAL -= Node.Total(0);
+ double u1 = Node.Total(0) / n1;
+ double incVariance = Node.Variance(0) + n1 * WIDTH * (u1 - TOTAL / WIDTH) * (u1 - TOTAL / WIDTH) / (n1 + WIDTH);
+ VARIANCE -= incVariance;
+
+ // Delete Bucket
+ Node.RemoveBucket();
+ BucketNumber--;
+ if (Node.bucketSizeRow == 0) {
+ listRowBuckets.removeFromTail();
+ lastBucketRow--;
+ }
+ return n1;
+ }
+
+ public void compressBuckets() {
+ // Traverse the list of buckets in increasing order
+ int n1, n2;
+ double u2, u1, incVariance;
+ ListItem cursor;
+ ListItem nextNode;
+ cursor = listRowBuckets.head();
+ int i = 0;
+ do {
+ // Find the number of buckets in a row
+ int k = cursor.bucketSizeRow;
+ // If the row is full, merge buckets
+ if (k == MAXBUCKETS + 1) {
+ nextNode = cursor.next();
+ if (nextNode == null) {
+ listRowBuckets.addToTail();
+ nextNode = cursor.next();
+ lastBucketRow++;
+ }
+ n1 = bucketSize(i);
+ n2 = bucketSize(i);
+ u1 = cursor.Total(0) / n1;
+ u2 = cursor.Total(1) / n2;
+ incVariance = n1 * n2 * (u1 - u2) * (u1 - u2) / (n1 + n2);
+
+ nextNode.insertBucket(cursor.Total(0) + cursor.Total(1), cursor.Variance(0) + cursor.Variance(1) + incVariance);
+ BucketNumber++;
+ cursor.compressBucketsRow(2);
+ if (nextNode.bucketSizeRow <= MAXBUCKETS) {
+ break;
+ }
+ } else {
+ break;
+ }
+ cursor = cursor.next();
+ i++;
+ } while (cursor != null);
+ }
+
+ public boolean setInput(double intEntrada) {
+ return setInput(intEntrada, mdbldelta);
+ }
+
+ public boolean setInput(double intEntrada, double delta) {
+ boolean blnChange = false;
+ boolean blnExit;
+ ListItem cursor;
+ mintTime++;
+
+ // 1,2)Increment window in one element
+ insertElement(intEntrada);
+ blnBucketDeleted = false;
+ // 3)Reduce window
+ if (mintTime % mintClock == 0 && getWidth() > mintMinimLongitudWindow) {
+ boolean blnReduceWidth = true; // Diference
+
+ while (blnReduceWidth) // Diference
+ {
+ blnReduceWidth = false; // Diference
+ blnExit = false;
+ int n0 = 0;
+ int n1 = WIDTH;
+ double u0 = 0;
+ double u1 = getTotal();
+ double v0 = 0;
+ double v1 = VARIANCE;
+ double n2;
+ double u2;
+
+ cursor = listRowBuckets.tail();
+ int i = lastBucketRow;
+ do {
+ for (int k = 0; k <= (cursor.bucketSizeRow - 1); k++) {
+ n2 = bucketSize(i);
+ u2 = cursor.Total(k);
+ if (n0 > 0) {
+ v0 += cursor.Variance(k) + (double) n0 * n2 * (u0 / n0 - u2 / n2) * (u0 / n0 - u2 / n2) / (n0 + n2);
+ }
+ if (n1 > 0) {
+ v1 -= cursor.Variance(k) + (double) n1 * n2 * (u1 / n1 - u2 / n2) * (u1 / n1 - u2 / n2) / (n1 + n2);
+ }
+
+ n0 += bucketSize(i);
+ n1 -= bucketSize(i);
+ u0 += cursor.Total(k);
+ u1 -= cursor.Total(k);
+
+ if (i == 0 && k == cursor.bucketSizeRow - 1) {
+ blnExit = true;
+ break;
+ }
+ double absvalue = (u0 / n0) - (u1 / n1); // n1<WIDTH-mintMinWinLength-1
+ if ((n1 > mintMinWinLength + 1 && n0 > mintMinWinLength + 1) && // Diference
+ // NEGATIVE
+ // if(
+ blnCutexpression(n0, n1, u0, u1, v0, v1, absvalue, delta)) {
+ blnBucketDeleted = true;
+ Detect = mintTime;
+
+ if (Detect == 0) {
+ Detect = mintTime;
+ // blnFirst=true;
+ // blnWarning=true;
+ } else if (DetectTwice == 0) {
+ DetectTwice = mintTime;
+ // blnDetect=true;
+ }
+ blnReduceWidth = true; // Diference
+ blnChange = true;
+ if (getWidth() > 0) { // Reduce width of the window
+ // while (n0>0) // Diference NEGATIVE
+ n0 -= deleteElement();
+ blnExit = true;
+ break;
+ }
+ } // End if
+ }// Next k
+ cursor = cursor.previous();
+ i--;
+ } while (((!blnExit && cursor != null)));
+ }// End While // Diference
+ }// End if
+
+ mdblWidth += getWidth();
+ if (blnChange) {
+ numberDetections++;
+ }
+ return blnChange;
+ }
+
+ private boolean blnCutexpression(int n0, int n1, double u0, double u1, double v0, double v1, double absvalue,
+ double delta) {
+ int n = getWidth();
+ double dd = Math.log(2 * Math.log(n) / delta); // -- ull perque el ln n va
+ // al numerador.
+ // Formula Gener 2008
+ double v = getVariance();
+ double m = ((double) 1 / ((n0 - mintMinWinLength + 1))) + ((double) 1 / ((n1 - mintMinWinLength + 1)));
+ double epsilon = Math.sqrt(2 * m * v * dd) + (double) 2 / 3 * dd * m;
+
+ return (Math.abs(absvalue) > epsilon);
+ }
+
+ public ADWIN() {
+ mdbldelta = DELTA;
+ initBuckets();
+ Detect = 0;
+ numberDetections = 0;
+ DetectTwice = 0;
+
+ }
+
+ public ADWIN(double d) {
+ mdbldelta = d;
+ initBuckets();
+ Detect = 0;
+ numberDetections = 0;
+ DetectTwice = 0;
+ }
+
+ public ADWIN(int cl) {
+ mdbldelta = DELTA;
+ initBuckets();
+ Detect = 0;
+ numberDetections = 0;
+ DetectTwice = 0;
+ mintClock = cl;
+ }
+
+ public String getEstimatorInfo() {
+ return "ADWIN;;";
+ }
+
+ public void setW(int W0) {
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWINChangeDetector.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWINChangeDetector.java
index b582351..4e7b3dc 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWINChangeDetector.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ADWINChangeDetector.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.classifiers.core.driftdetection;
/*
@@ -25,49 +24,48 @@
import com.yahoo.labs.samoa.moa.core.ObjectRepository;
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
-
/**
* Drift detection method based in ADWIN. ADaptive sliding WINdow is a change
* detector and estimator. It keeps a variable-length window of recently seen
* items, with the property that the window has the maximal length statistically
* consistent with the hypothesis "there has been no change in the average value
* inside the window".
- *
- *
+ *
+ *
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
* @version $Revision: 7 $
*/
public class ADWINChangeDetector extends AbstractChangeDetector {
- protected ADWIN adwin;
+ protected ADWIN adwin;
- public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a',
- "Delta of Adwin change detection", 0.002, 0.0, 1.0);
+ public FloatOption deltaAdwinOption = new FloatOption("deltaAdwin", 'a',
+ "Delta of Adwin change detection", 0.002, 0.0, 1.0);
- @Override
- public void input(double inputValue) {
- if (this.adwin == null) {
- resetLearning();
- }
- this.isChangeDetected = adwin.setInput(inputValue);
- this.isWarningZone = false;
- this.delay = 0.0;
- this.estimation = adwin.getEstimation();
+ @Override
+ public void input(double inputValue) {
+ if (this.adwin == null) {
+ resetLearning();
}
+ this.isChangeDetected = adwin.setInput(inputValue);
+ this.isWarningZone = false;
+ this.delay = 0.0;
+ this.estimation = adwin.getEstimation();
+ }
- @Override
- public void resetLearning() {
- adwin = new ADWIN(this.deltaAdwinOption.getValue());
- }
+ @Override
+ public void resetLearning() {
+ adwin = new ADWIN(this.deltaAdwinOption.getValue());
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/AbstractChangeDetector.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/AbstractChangeDetector.java
index ff591ab..a06707c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/AbstractChangeDetector.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/AbstractChangeDetector.java
@@ -24,120 +24,122 @@
/**
* Abstract Change Detector. All change detectors in MOA extend this class.
- *
+ *
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
* @version $Revision: 7 $
*/
public abstract class AbstractChangeDetector extends AbstractOptionHandler
- implements ChangeDetector {
+ implements ChangeDetector {
+ /**
+ * Change was detected
+ */
+ protected boolean isChangeDetected;
+ /**
+ * Warning Zone: after a warning and before a change
+ */
+ protected boolean isWarningZone;
- /**
- * Change was detected
- */
- protected boolean isChangeDetected;
+ /**
+ * Prediction for the next value based in previous seen values
+ */
+ protected double estimation;
- /**
- * Warning Zone: after a warning and before a change
- */
- protected boolean isWarningZone;
+ /**
+ * Delay in detecting change
+ */
+ protected double delay;
- /**
- * Prediction for the next value based in previous seen values
- */
- protected double estimation;
+ /**
+ * Resets this change detector. It must be similar to starting a new change
+ * detector from scratch.
+ *
+ */
+ public void resetLearning() {
+ this.isChangeDetected = false;
+ this.isWarningZone = false;
+ this.estimation = 0.0;
+ this.delay = 0.0;
+ }
- /**
- * Delay in detecting change
- */
- protected double delay;
+ /**
+ * Adding a numeric value to the change detector<br>
+ * <br>
+ *
+ * The output of the change detector is modified after the insertion of a new
+ * item inside.
+ *
+ * @param inputValue
+ * the number to insert into the change detector
+ */
+ public abstract void input(double inputValue);
- /**
- * Resets this change detector. It must be similar to starting a new change
- * detector from scratch.
- *
- */
- public void resetLearning() {
- this.isChangeDetected = false;
- this.isWarningZone = false;
- this.estimation = 0.0;
- this.delay = 0.0;
- }
+ /**
+ * Gets whether there is change detected.
+ *
+ * @return true if there is change
+ */
+ public boolean getChange() {
+ return this.isChangeDetected;
+ }
- /**
- * Adding a numeric value to the change detector<br><br>
- *
- * The output of the change detector is modified after the insertion of a
- * new item inside.
- *
- * @param inputValue the number to insert into the change detector
- */
- public abstract void input(double inputValue);
+ /**
+ * Gets whether the change detector is in the warning zone, after a warning
+ * alert and before a change alert.
+ *
+ * @return true if the change detector is in the warning zone
+ */
+ public boolean getWarningZone() {
+ return this.isWarningZone;
+ }
- /**
- * Gets whether there is change detected.
- *
- * @return true if there is change
- */
- public boolean getChange() {
- return this.isChangeDetected;
- }
+ /**
+ * Gets the prediction of next values.
+ *
+ * @return a prediction of the next value
+ */
+ public double getEstimation() {
+ return this.estimation;
+ }
- /**
- * Gets whether the change detector is in the warning zone, after a warning
- * alert and before a change alert.
- *
- * @return true if the change detector is in the warning zone
- */
- public boolean getWarningZone() {
- return this.isWarningZone;
- }
+ /**
+ * Gets the length of the delay in the change detected.
+ *
+ * @return the length of the delay in the change detected
+ */
+ public double getDelay() {
+ return this.delay;
+ }
- /**
- * Gets the prediction of next values.
- *
- * @return a prediction of the next value
- */
- public double getEstimation() {
- return this.estimation;
- }
+ /**
+ * Gets the output state of the change detection.
+ *
+ * @return an array with the number of change detections, number of warnings,
+ * delay, and estimation.
+ */
+ public double[] getOutput() {
+ return new double[] { this.isChangeDetected ? 1 : 0, this.isWarningZone ? 1 : 0, this.delay, this.estimation };
+ }
- /**
- * Gets the length of the delay in the change detected.
- *
- * @return he length of the delay in the change detected
- */
- public double getDelay() {
- return this.delay;
- }
+ /**
+ * Returns a string representation of the model.
+ *
+ * @param sb
+ * the stringbuilder to add the description
+ * @param indent
+ * the number of characters to indent
+ */
+ @Override
+ public abstract void getDescription(StringBuilder sb, int indent);
- /**
- * Gets the output state of the change detection.
- *
- * @return an array with the number of change detections, number of
- * warnings, delay, and estimation.
- */
- public double[] getOutput() {
- return new double[]{this.isChangeDetected ? 1 : 0, this.isWarningZone ? 1 : 0, this.delay, this.estimation};
- }
-
- /**
- * Returns a string representation of the model.
- *
- * @param sb the stringbuilder to add the description
- * @param indent the number of characters to indent
- */
- @Override
- public abstract void getDescription(StringBuilder sb, int indent);
-
- /**
- * Produces a copy of this change detector method
- *
- * @return the copy of this change detector method
- */
- @Override
- public ChangeDetector copy() {
- return (ChangeDetector) super.copy();
- }
+ /**
+ * Produces a copy of this change detector method
+ *
+ * @return the copy of this change detector method
+ */
+ @Override
+ public ChangeDetector copy() {
+ return (ChangeDetector) super.copy();
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ChangeDetector.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ChangeDetector.java
index 47dc203..d2bae2b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ChangeDetector.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/ChangeDetector.java
@@ -23,80 +23,85 @@
import com.yahoo.labs.samoa.moa.options.OptionHandler;
/**
- * Change Detector interface to implement methods that detects change.
- *
+ * Change Detector interface to implement methods that detect change.
+ *
* @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
* @version $Revision: 7 $
*/
public interface ChangeDetector extends OptionHandler {
- /**
- * Resets this change detector. It must be similar to starting a new change
- * detector from scratch.
- *
- */
- public void resetLearning();
+ /**
+ * Resets this change detector. It must be similar to starting a new change
+ * detector from scratch.
+ *
+ */
+ public void resetLearning();
- /**
- * Adding a numeric value to the change detector<br><br>
- *
- * The output of the change detector is modified after the insertion of a
- * new item inside.
- *
- * @param inputValue the number to insert into the change detector
- */
- public void input(double inputValue);
+ /**
+ * Adding a numeric value to the change detector<br>
+ * <br>
+ *
+ * The output of the change detector is modified after the insertion of a new
+ * item inside.
+ *
+ * @param inputValue
+ * the number to insert into the change detector
+ */
+ public void input(double inputValue);
- /**
- * Gets whether there is change detected.
- *
- * @return true if there is change
- */
- public boolean getChange();
+ /**
+ * Gets whether there is change detected.
+ *
+ * @return true if there is change
+ */
+ public boolean getChange();
- /**
- * Gets whether the change detector is in the warning zone, after a warning alert and before a change alert.
- *
- * @return true if the change detector is in the warning zone
- */
- public boolean getWarningZone();
+ /**
+ * Gets whether the change detector is in the warning zone, after a warning
+ * alert and before a change alert.
+ *
+ * @return true if the change detector is in the warning zone
+ */
+ public boolean getWarningZone();
- /**
- * Gets the prediction of next values.
- *
- * @return a prediction of the next value
- */
- public double getEstimation();
+ /**
+ * Gets the prediction of next values.
+ *
+ * @return a prediction of the next value
+ */
+ public double getEstimation();
- /**
- * Gets the length of the delay in the change detected.
- *
- * @return he length of the delay in the change detected
- */
- public double getDelay();
+ /**
+ * Gets the length of the delay in the change detected.
+ *
+ * @return the length of the delay in the change detected
+ */
+ public double getDelay();
- /**
- * Gets the output state of the change detection.
- *
- * @return an array with the number of change detections, number of
- * warnings, delay, and estimation.
- */
- public double[] getOutput();
+ /**
+ * Gets the output state of the change detection.
+ *
+ * @return an array with the number of change detections, number of warnings,
+ * delay, and estimation.
+ */
+ public double[] getOutput();
- /**
- * Returns a string representation of the model.
- *
- * @param out the stringbuilder to add the description
- * @param indent the number of characters to indent
- */
- @Override
- public void getDescription(StringBuilder sb, int indent);
+ /**
+ * Returns a string representation of the model.
+ *
+ * @param sb
+ * the stringbuilder to add the description
+ * @param indent
+ * the number of characters to indent
+ */
+ @Override
+ public void getDescription(StringBuilder sb, int indent);
- /**
- * Produces a copy of this drift detection method
- *
- * @return the copy of this drift detection method
- */
- @Override
- public ChangeDetector copy();
+ /**
+ * Produces a copy of this drift detection method
+ *
+ * @return the copy of this drift detection method
+ */
+ @Override
+ public ChangeDetector copy();
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/CusumDM.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/CusumDM.java
index e372ae8..88bce73 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/CusumDM.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/CusumDM.java
@@ -27,84 +27,84 @@
/**
* Drift detection method based in Cusum
- *
- *
+ *
+ *
* @author Manuel Baena (mbaena@lcc.uma.es)
* @version $Revision: 7 $
*/
public class CusumDM extends AbstractChangeDetector {
- private static final long serialVersionUID = -3518369648142099719L;
+ private static final long serialVersionUID = -3518369648142099719L;
- public IntOption minNumInstancesOption = new IntOption(
- "minNumInstances",
- 'n',
- "The minimum number of instances before permitting detecting change.",
- 30, 0, Integer.MAX_VALUE);
+ public IntOption minNumInstancesOption = new IntOption(
+ "minNumInstances",
+ 'n',
+ "The minimum number of instances before permitting detecting change.",
+ 30, 0, Integer.MAX_VALUE);
- public FloatOption deltaOption = new FloatOption("delta", 'd',
- "Delta parameter of the Cusum Test", 0.005, 0.0, 1.0);
+ public FloatOption deltaOption = new FloatOption("delta", 'd',
+ "Delta parameter of the Cusum Test", 0.005, 0.0, 1.0);
- public FloatOption lambdaOption = new FloatOption("lambda", 'l',
- "Threshold parameter of the Cusum Test", 50, 0.0, Float.MAX_VALUE);
+ public FloatOption lambdaOption = new FloatOption("lambda", 'l',
+ "Threshold parameter of the Cusum Test", 50, 0.0, Float.MAX_VALUE);
- private int m_n;
+ private int m_n;
- private double sum;
+ private double sum;
- private double x_mean;
+ private double x_mean;
- private double delta;
+ private double delta;
- private double lambda;
+ private double lambda;
- public CusumDM() {
- resetLearning();
+ public CusumDM() {
+ resetLearning();
+ }
+
+ @Override
+ public void resetLearning() {
+ m_n = 1;
+ x_mean = 0.0;
+ sum = 0.0;
+ delta = this.deltaOption.getValue();
+ lambda = this.lambdaOption.getValue();
+ }
+
+ @Override
+ public void input(double x) {
+ // It monitors the error rate
+ if (this.isChangeDetected) {
+ resetLearning();
}
- @Override
- public void resetLearning() {
- m_n = 1;
- x_mean = 0.0;
- sum = 0.0;
- delta = this.deltaOption.getValue();
- lambda = this.lambdaOption.getValue();
+ x_mean = x_mean + (x - x_mean) / (double) m_n;
+ sum = Math.max(0, sum + x - x_mean - this.delta);
+ m_n++;
+
+ // System.out.print(prediction + " " + m_n + " " + (m_p+m_s) + " ");
+ this.estimation = x_mean;
+ this.isChangeDetected = false;
+ this.isWarningZone = false;
+ this.delay = 0;
+
+ if (m_n < this.minNumInstancesOption.getValue()) {
+ return;
}
- @Override
- public void input(double x) {
- // It monitors the error rate
- if (this.isChangeDetected) {
- resetLearning();
- }
-
- x_mean = x_mean + (x - x_mean) / (double) m_n;
- sum = Math.max(0, sum + x - x_mean - this.delta);
- m_n++;
-
- // System.out.print(prediction + " " + m_n + " " + (m_p+m_s) + " ");
- this.estimation = x_mean;
- this.isChangeDetected = false;
- this.isWarningZone = false;
- this.delay = 0;
-
- if (m_n < this.minNumInstancesOption.getValue()) {
- return;
- }
-
- if (sum > this.lambda) {
- this.isChangeDetected = true;
- }
+ if (sum > this.lambda) {
+ this.isChangeDetected = true;
}
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/DDM.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/DDM.java
index af5dd52..1bf2cdd 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/DDM.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/DDM.java
@@ -25,93 +25,95 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Drift detection method based in DDM method of Joao Gama SBIA 2004.
- *
- * <p>João Gama, Pedro Medas, Gladys Castillo, Pedro Pereira Rodrigues: Learning
- * with Drift Detection. SBIA 2004: 286-295 </p>
- *
- * @author Manuel Baena (mbaena@lcc.uma.es)
- * @version $Revision: 7 $
+ * Drift detection method based in DDM method of Joao Gama SBIA 2004.
+ *
+ * <p>
+ * João Gama, Pedro Medas, Gladys Castillo, Pedro Pereira Rodrigues: Learning
+ * with Drift Detection. SBIA 2004: 286-295
+ * </p>
+ *
+ * @author Manuel Baena (mbaena@lcc.uma.es)
+ * @version $Revision: 7 $
*/
public class DDM extends AbstractChangeDetector {
- private static final long serialVersionUID = -3518369648142099719L;
+ private static final long serialVersionUID = -3518369648142099719L;
- //private static final int DDM_MINNUMINST = 30;
- public IntOption minNumInstancesOption = new IntOption(
- "minNumInstances",
- 'n',
- "The minimum number of instances before permitting detecting change.",
- 30, 0, Integer.MAX_VALUE);
- private int m_n;
+ // private static final int DDM_MINNUMINST = 30;
+ public IntOption minNumInstancesOption = new IntOption(
+ "minNumInstances",
+ 'n',
+ "The minimum number of instances before permitting detecting change.",
+ 30, 0, Integer.MAX_VALUE);
+ private int m_n;
- private double m_p;
+ private double m_p;
- private double m_s;
+ private double m_s;
- private double m_psmin;
+ private double m_psmin;
- private double m_pmin;
+ private double m_pmin;
- private double m_smin;
+ private double m_smin;
- public DDM() {
- resetLearning();
+ public DDM() {
+ resetLearning();
+ }
+
+ @Override
+ public void resetLearning() {
+ m_n = 1;
+ m_p = 1;
+ m_s = 0;
+ m_psmin = Double.MAX_VALUE;
+ m_pmin = Double.MAX_VALUE;
+ m_smin = Double.MAX_VALUE;
+ }
+
+ @Override
+ public void input(double prediction) {
+ // prediction must be 1 or 0
+ // It monitors the error rate
+ if (this.isChangeDetected) {
+ resetLearning();
+ }
+ m_p = m_p + (prediction - m_p) / (double) m_n;
+ m_s = Math.sqrt(m_p * (1 - m_p) / (double) m_n);
+
+ m_n++;
+
+ // System.out.print(prediction + " " + m_n + " " + (m_p+m_s) + " ");
+ this.estimation = m_p;
+ this.isChangeDetected = false;
+ this.isWarningZone = false;
+ this.delay = 0;
+
+ if (m_n < this.minNumInstancesOption.getValue()) {
+ return;
}
- @Override
- public void resetLearning() {
- m_n = 1;
- m_p = 1;
- m_s = 0;
- m_psmin = Double.MAX_VALUE;
- m_pmin = Double.MAX_VALUE;
- m_smin = Double.MAX_VALUE;
+ if (m_p + m_s <= m_psmin) {
+ m_pmin = m_p;
+ m_smin = m_s;
+ m_psmin = m_p + m_s;
}
- @Override
- public void input(double prediction) {
- // prediction must be 1 or 0
- // It monitors the error rate
- if (this.isChangeDetected) {
- resetLearning();
- }
- m_p = m_p + (prediction - m_p) / (double) m_n;
- m_s = Math.sqrt(m_p * (1 - m_p) / (double) m_n);
+ if (m_n > this.minNumInstancesOption.getValue() && m_p + m_s > m_pmin + 3 * m_smin) {
+ this.isChangeDetected = true;
+ // resetLearning();
+ } else
+ this.isWarningZone = m_p + m_s > m_pmin + 2 * m_smin;
+ }
- m_n++;
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- // System.out.print(prediction + " " + m_n + " " + (m_p+m_s) + " ");
- this.estimation = m_p;
- this.isChangeDetected = false;
- this.isWarningZone = false;
- this.delay = 0;
-
- if (m_n < this.minNumInstancesOption.getValue()) {
- return;
- }
-
- if (m_p + m_s <= m_psmin) {
- m_pmin = m_p;
- m_smin = m_s;
- m_psmin = m_p + m_s;
- }
-
- if (m_n > this.minNumInstancesOption.getValue() && m_p + m_s > m_pmin + 3 * m_smin) {
- this.isChangeDetected = true;
- //resetLearning();
- } else
- this.isWarningZone = m_p + m_s > m_pmin + 2 * m_smin;
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
-
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EDDM.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EDDM.java
index 2ab1022..b845120 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EDDM.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EDDM.java
@@ -25,114 +25,116 @@
/**
* Drift detection method based in EDDM method of Manuel Baena et al.
- *
- * <p>Early Drift Detection Method. Manuel Baena-Garcia, Jose Del Campo-Avila,
- * Raúl Fidalgo, Albert Bifet, Ricard Gavalda, Rafael Morales-Bueno. In Fourth
- * International Workshop on Knowledge Discovery from Data Streams, 2006.</p>
- *
+ *
+ * <p>
+ * Early Drift Detection Method. Manuel Baena-Garcia, Jose Del Campo-Avila, Raúl
+ * Fidalgo, Albert Bifet, Ricard Gavalda, Rafael Morales-Bueno. In Fourth
+ * International Workshop on Knowledge Discovery from Data Streams, 2006.
+ * </p>
+ *
* @author Manuel Baena (mbaena@lcc.uma.es)
* @version $Revision: 7 $
*/
public class EDDM extends AbstractChangeDetector {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 140980267062162000L;
+ private static final long serialVersionUID = 140980267062162000L;
- private static final double FDDM_OUTCONTROL = 0.9;
+ private static final double FDDM_OUTCONTROL = 0.9;
- private static final double FDDM_WARNING = 0.95;
+ private static final double FDDM_WARNING = 0.95;
- private static final double FDDM_MINNUMINSTANCES = 30;
+ private static final double FDDM_MINNUMINSTANCES = 30;
- private double m_numErrors;
+ private double m_numErrors;
- private int m_minNumErrors = 30;
+ private int m_minNumErrors = 30;
- private int m_n;
+ private int m_n;
- private int m_d;
+ private int m_d;
- private int m_lastd;
+ private int m_lastd;
- private double m_mean;
+ private double m_mean;
- private double m_stdTemp;
+ private double m_stdTemp;
- private double m_m2smax;
+ private double m_m2smax;
- public EDDM() {
- resetLearning();
+ public EDDM() {
+ resetLearning();
+ }
+
+ @Override
+ public void resetLearning() {
+ m_n = 1;
+ m_numErrors = 0;
+ m_d = 0;
+ m_lastd = 0;
+ m_mean = 0.0;
+ m_stdTemp = 0.0;
+ m_m2smax = 0.0;
+ this.estimation = 0.0;
+ }
+
+ @Override
+ public void input(double prediction) {
+ // prediction must be 1 or 0
+ // It monitors the error rate
+ // System.out.print(prediction + " " + m_n + " " + probability + " ");
+ if (this.isChangeDetected) {
+ resetLearning();
}
+ this.isChangeDetected = false;
- @Override
- public void resetLearning() {
- m_n = 1;
- m_numErrors = 0;
- m_d = 0;
- m_lastd = 0;
- m_mean = 0.0;
- m_stdTemp = 0.0;
- m_m2smax = 0.0;
- this.estimation = 0.0;
- }
+ m_n++;
+ if (prediction == 1.0) {
+ this.isWarningZone = false;
+ this.delay = 0;
+ m_numErrors += 1;
+ m_lastd = m_d;
+ m_d = m_n - 1;
+ int distance = m_d - m_lastd;
+ double oldmean = m_mean;
+ m_mean = m_mean + ((double) distance - m_mean) / m_numErrors;
+ this.estimation = m_mean;
+ m_stdTemp = m_stdTemp + (distance - m_mean) * (distance - oldmean);
+ double std = Math.sqrt(m_stdTemp / m_numErrors);
+ double m2s = m_mean + 2 * std;
- @Override
- public void input(double prediction) {
- // prediction must be 1 or 0
- // It monitors the error rate
- // System.out.print(prediction + " " + m_n + " " + probability + " ");
- if (this.isChangeDetected) {
- resetLearning();
+ if (m2s > m_m2smax) {
+ if (m_n > FDDM_MINNUMINSTANCES) {
+ m_m2smax = m2s;
}
- this.isChangeDetected = false;
-
- m_n++;
- if (prediction == 1.0) {
- this.isWarningZone = false;
- this.delay = 0;
- m_numErrors += 1;
- m_lastd = m_d;
- m_d = m_n - 1;
- int distance = m_d - m_lastd;
- double oldmean = m_mean;
- m_mean = m_mean + ((double) distance - m_mean) / m_numErrors;
- this.estimation = m_mean;
- m_stdTemp = m_stdTemp + (distance - m_mean) * (distance - oldmean);
- double std = Math.sqrt(m_stdTemp / m_numErrors);
- double m2s = m_mean + 2 * std;
-
- if (m2s > m_m2smax) {
- if (m_n > FDDM_MINNUMINSTANCES) {
- m_m2smax = m2s;
- }
- //m_lastLevel = DDM_INCONTROL_LEVEL;
- // System.out.print(1 + " ");
- } else {
- double p = m2s / m_m2smax;
- // System.out.print(p + " ");
- if (m_n > FDDM_MINNUMINSTANCES && m_numErrors > m_minNumErrors
- && p < FDDM_OUTCONTROL) {
- //System.out.println(m_mean + ",D");
- this.isChangeDetected = true;
- //resetLearning();
- } else {
- this.isWarningZone = m_n > FDDM_MINNUMINSTANCES
- && m_numErrors > m_minNumErrors && p < FDDM_WARNING;
- }
- }
+ // m_lastLevel = DDM_INCONTROL_LEVEL;
+ // System.out.print(1 + " ");
+ } else {
+ double p = m2s / m_m2smax;
+ // System.out.print(p + " ");
+ if (m_n > FDDM_MINNUMINSTANCES && m_numErrors > m_minNumErrors
+ && p < FDDM_OUTCONTROL) {
+ // System.out.println(m_mean + ",D");
+ this.isChangeDetected = true;
+ // resetLearning();
+ } else {
+ this.isWarningZone = m_n > FDDM_MINNUMINSTANCES
+ && m_numErrors > m_minNumErrors && p < FDDM_WARNING;
}
+ }
}
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EWMAChartDM.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EWMAChartDM.java
index 7f98099..6394259 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EWMAChartDM.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/EWMAChartDM.java
@@ -28,95 +28,96 @@
/**
* Drift detection method based in EWMA Charts of Ross, Adams, Tasoulis and Hand
* 2012
- *
- *
+ *
+ *
* @author Manuel Baena (mbaena@lcc.uma.es)
* @version $Revision: 7 $
*/
public class EWMAChartDM extends AbstractChangeDetector {
- private static final long serialVersionUID = -3518369648142099719L;
+ private static final long serialVersionUID = -3518369648142099719L;
- //private static final int DDM_MIN_NUM_INST = 30;
- public IntOption minNumInstancesOption = new IntOption(
- "minNumInstances",
- 'n',
- "The minimum number of instances before permitting detecting change.",
- 30, 0, Integer.MAX_VALUE);
+ // private static final int DDM_MIN_NUM_INST = 30;
+ public IntOption minNumInstancesOption = new IntOption(
+ "minNumInstances",
+ 'n',
+ "The minimum number of instances before permitting detecting change.",
+ 30, 0, Integer.MAX_VALUE);
- public FloatOption lambdaOption = new FloatOption("lambda", 'l',
- "Lambda parameter of the EWMA Chart Method", 0.2, 0.0, Float.MAX_VALUE);
+ public FloatOption lambdaOption = new FloatOption("lambda", 'l',
+ "Lambda parameter of the EWMA Chart Method", 0.2, 0.0, Float.MAX_VALUE);
- private double m_n;
+ private double m_n;
- private double m_sum;
-
- private double m_p;
-
- private double m_s;
-
- private double lambda;
-
- private double z_t;
+ private double m_sum;
- public EWMAChartDM() {
- resetLearning();
+ private double m_p;
+
+ private double m_s;
+
+ private double lambda;
+
+ private double z_t;
+
+ public EWMAChartDM() {
+ resetLearning();
+ }
+
+ @Override
+ public void resetLearning() {
+ m_n = 1.0;
+ m_sum = 0.0;
+ m_p = 0.0;
+ m_s = 0.0;
+ z_t = 0.0;
+ lambda = this.lambdaOption.getValue();
+ }
+
+ @Override
+ public void input(double prediction) {
+ // prediction must be 1 or 0
+ // It monitors the error rate
+ if (this.isChangeDetected) {
+ resetLearning();
}
- @Override
- public void resetLearning() {
- m_n = 1.0;
- m_sum = 0.0;
- m_p = 0.0;
- m_s = 0.0;
- z_t = 0.0;
- lambda = this.lambdaOption.getValue();
+ m_sum += prediction;
+
+ m_p = m_sum / m_n; // m_p + (prediction - m_p) / (double) (m_n+1);
+
+ m_s = Math.sqrt(m_p * (1.0 - m_p) * lambda * (1.0 - Math.pow(1.0 - lambda, 2.0 * m_n)) / (2.0 - lambda));
+
+ m_n++;
+
+ z_t += lambda * (prediction - z_t);
+
+ double L_t = 3.97 - 6.56 * m_p + 48.73 * Math.pow(m_p, 3) - 330.13 * Math.pow(m_p, 5) + 848.18 * Math.pow(m_p, 7); // %1
+ // FP
+ this.estimation = m_p;
+ this.isChangeDetected = false;
+ this.isWarningZone = false;
+ this.delay = 0;
+
+ if (m_n < this.minNumInstancesOption.getValue()) {
+ return;
}
- @Override
- public void input(double prediction) {
- // prediction must be 1 or 0
- // It monitors the error rate
- if (this.isChangeDetected) {
- resetLearning();
- }
-
- m_sum += prediction;
-
- m_p = m_sum/m_n; // m_p + (prediction - m_p) / (double) (m_n+1);
-
- m_s = Math.sqrt( m_p * (1.0 - m_p)* lambda * (1.0 - Math.pow(1.0 - lambda, 2.0 * m_n)) / (2.0 - lambda));
-
- m_n++;
-
- z_t += lambda * (prediction - z_t);
-
- double L_t = 3.97 - 6.56 * m_p + 48.73 * Math.pow(m_p, 3) - 330.13 * Math.pow(m_p, 5) + 848.18 * Math.pow(m_p, 7); //%1 FP
- this.estimation = m_p;
- this.isChangeDetected = false;
- this.isWarningZone = false;
- this.delay = 0;
-
- if (m_n < this.minNumInstancesOption.getValue()) {
- return;
- }
-
- if (m_n > this.minNumInstancesOption.getValue() && z_t > m_p + L_t * m_s) {
- this.isChangeDetected = true;
- //resetLearning();
- } else {
- this.isWarningZone = z_t > m_p + 0.5 * L_t * m_s;
- }
+ if (m_n > this.minNumInstancesOption.getValue() && z_t > m_p + L_t * m_s) {
+ this.isChangeDetected = true;
+ // resetLearning();
+ } else {
+ this.isWarningZone = z_t > m_p + 0.5 * L_t * m_s;
}
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/GeometricMovingAverageDM.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/GeometricMovingAverageDM.java
index 5ed44e5..fc4335f 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/GeometricMovingAverageDM.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/GeometricMovingAverageDM.java
@@ -27,82 +27,82 @@
/**
* Drift detection method based in Geometric Moving Average Test
- *
- *
+ *
+ *
* @author Manuel Baena (mbaena@lcc.uma.es)
* @version $Revision: 7 $
*/
public class GeometricMovingAverageDM extends AbstractChangeDetector {
- private static final long serialVersionUID = -3518369648142099719L;
+ private static final long serialVersionUID = -3518369648142099719L;
- public IntOption minNumInstancesOption = new IntOption(
- "minNumInstances",
- 'n',
- "The minimum number of instances before permitting detecting change.",
- 30, 0, Integer.MAX_VALUE);
+ public IntOption minNumInstancesOption = new IntOption(
+ "minNumInstances",
+ 'n',
+ "The minimum number of instances before permitting detecting change.",
+ 30, 0, Integer.MAX_VALUE);
- public FloatOption lambdaOption = new FloatOption("lambda", 'l',
- "Threshold parameter of the Geometric Moving Average Test", 1, 0.0, Float.MAX_VALUE);
+ public FloatOption lambdaOption = new FloatOption("lambda", 'l',
+ "Threshold parameter of the Geometric Moving Average Test", 1, 0.0, Float.MAX_VALUE);
- public FloatOption alphaOption = new FloatOption("alpha", 'a',
- "Alpha parameter of the Geometric Moving Average Test", .99, 0.0, 1.0);
+ public FloatOption alphaOption = new FloatOption("alpha", 'a',
+ "Alpha parameter of the Geometric Moving Average Test", .99, 0.0, 1.0);
- private double m_n;
+ private double m_n;
- private double sum;
+ private double sum;
- private double x_mean;
-
- private double alpha;
+ private double x_mean;
- private double lambda;
+ private double alpha;
- public GeometricMovingAverageDM() {
- resetLearning();
+ private double lambda;
+
+ public GeometricMovingAverageDM() {
+ resetLearning();
+ }
+
+ @Override
+ public void resetLearning() {
+ m_n = 1.0;
+ x_mean = 0.0;
+ sum = 0.0;
+ alpha = this.alphaOption.getValue();
+ lambda = this.lambdaOption.getValue();
+ }
+
+ @Override
+ public void input(double x) {
+ // It monitors the error rate
+ if (this.isChangeDetected) {
+ resetLearning();
}
- @Override
- public void resetLearning() {
- m_n = 1.0;
- x_mean = 0.0;
- sum = 0.0;
- alpha = this.alphaOption.getValue();
- lambda = this.lambdaOption.getValue();
+ x_mean = x_mean + (x - x_mean) / m_n;
+ sum = alpha * sum + (1.0 - alpha) * (x - x_mean);
+ m_n++;
+ this.estimation = x_mean;
+ this.isChangeDetected = false;
+ this.isWarningZone = false;
+ this.delay = 0;
+
+ if (m_n < this.minNumInstancesOption.getValue()) {
+ return;
}
- @Override
- public void input(double x) {
- // It monitors the error rate
- if (this.isChangeDetected) {
- resetLearning();
- }
-
- x_mean = x_mean + (x - x_mean) / m_n;
- sum = alpha * sum + (1.0- alpha) * (x - x_mean);
- m_n++;
- this.estimation = x_mean;
- this.isChangeDetected = false;
- this.isWarningZone = false;
- this.delay = 0;
-
- if (m_n < this.minNumInstancesOption.getValue()) {
- return;
- }
-
- if (sum > this.lambda) {
- this.isChangeDetected = true;
- }
+ if (sum > this.lambda) {
+ this.isChangeDetected = true;
}
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/PageHinkleyDM.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/PageHinkleyDM.java
index 7a94a17..a01021a 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/PageHinkleyDM.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/driftdetection/PageHinkleyDM.java
@@ -27,88 +27,88 @@
/**
* Drift detection method based in Page Hinkley Test.
- *
- *
+ *
+ *
* @author Manuel Baena (mbaena@lcc.uma.es)
* @version $Revision: 7 $
*/
public class PageHinkleyDM extends AbstractChangeDetector {
- private static final long serialVersionUID = -3518369648142099719L;
+ private static final long serialVersionUID = -3518369648142099719L;
- public IntOption minNumInstancesOption = new IntOption(
- "minNumInstances",
- 'n',
- "The minimum number of instances before permitting detecting change.",
- 30, 0, Integer.MAX_VALUE);
+ public IntOption minNumInstancesOption = new IntOption(
+ "minNumInstances",
+ 'n',
+ "The minimum number of instances before permitting detecting change.",
+ 30, 0, Integer.MAX_VALUE);
- public FloatOption deltaOption = new FloatOption("delta", 'd',
- "Delta parameter of the Page Hinkley Test", 0.005, 0.0, 1.0);
+ public FloatOption deltaOption = new FloatOption("delta", 'd',
+ "Delta parameter of the Page Hinkley Test", 0.005, 0.0, 1.0);
- public FloatOption lambdaOption = new FloatOption("lambda", 'l',
- "Lambda parameter of the Page Hinkley Test", 50, 0.0, Float.MAX_VALUE);
+ public FloatOption lambdaOption = new FloatOption("lambda", 'l',
+ "Lambda parameter of the Page Hinkley Test", 50, 0.0, Float.MAX_VALUE);
- public FloatOption alphaOption = new FloatOption("alpha", 'a',
- "Alpha parameter of the Page Hinkley Test", 1 - 0.0001, 0.0, 1.0);
+ public FloatOption alphaOption = new FloatOption("alpha", 'a',
+ "Alpha parameter of the Page Hinkley Test", 1 - 0.0001, 0.0, 1.0);
- private int m_n;
+ private int m_n;
- private double sum;
+ private double sum;
- private double x_mean;
+ private double x_mean;
- private double alpha;
+ private double alpha;
- private double delta;
+ private double delta;
- private double lambda;
+ private double lambda;
- public PageHinkleyDM() {
- resetLearning();
+ public PageHinkleyDM() {
+ resetLearning();
+ }
+
+ @Override
+ public void resetLearning() {
+ m_n = 1;
+ x_mean = 0.0;
+ sum = 0.0;
+ delta = this.deltaOption.getValue();
+ alpha = this.alphaOption.getValue();
+ lambda = this.lambdaOption.getValue();
+ }
+
+ @Override
+ public void input(double x) {
+ // It monitors the error rate
+ if (this.isChangeDetected) {
+ resetLearning();
}
- @Override
- public void resetLearning() {
- m_n = 1;
- x_mean = 0.0;
- sum = 0.0;
- delta = this.deltaOption.getValue();
- alpha = this.alphaOption.getValue();
- lambda = this.lambdaOption.getValue();
+ x_mean = x_mean + (x - x_mean) / (double) m_n;
+ sum = this.alpha * sum + (x - x_mean - this.delta);
+ m_n++;
+ this.estimation = x_mean;
+ this.isChangeDetected = false;
+ this.isWarningZone = false;
+ this.delay = 0;
+
+ if (m_n < this.minNumInstancesOption.getValue()) {
+ return;
}
- @Override
- public void input(double x) {
- // It monitors the error rate
- if (this.isChangeDetected) {
- resetLearning();
- }
-
- x_mean = x_mean + (x - x_mean) / (double) m_n;
- sum = this.alpha * sum + (x - x_mean - this.delta);
- m_n++;
- this.estimation = x_mean;
- this.isChangeDetected = false;
- this.isWarningZone = false;
- this.delay = 0;
-
- if (m_n < this.minNumInstancesOption.getValue()) {
- return;
- }
-
- if (sum > this.lambda) {
- this.isChangeDetected = true;
- }
+ if (sum > this.lambda) {
+ this.isChangeDetected = true;
}
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/GiniSplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/GiniSplitCriterion.java
index 185b12f..b5135ea 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/GiniSplitCriterion.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/GiniSplitCriterion.java
@@ -26,61 +26,60 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Class for computing splitting criteria using Gini
- * with respect to distributions of class values.
- * The split criterion is used as a parameter on
+ * Class for computing splitting criteria using Gini with respect to
+ * distributions of class values. The split criterion is used as a parameter on
* decision trees and decision stumps.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class GiniSplitCriterion extends AbstractOptionHandler implements
- SplitCriterion {
+ SplitCriterion {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- @Override
- public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) {
- double totalWeight = 0.0;
- double[] distWeights = new double[postSplitDists.length];
- for (int i = 0; i < postSplitDists.length; i++) {
- distWeights[i] = Utils.sum(postSplitDists[i]);
- totalWeight += distWeights[i];
- }
- double gini = 0.0;
- for (int i = 0; i < postSplitDists.length; i++) {
- gini += (distWeights[i] / totalWeight)
- * computeGini(postSplitDists[i], distWeights[i]);
- }
- return 1.0 - gini;
+ @Override
+ public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) {
+ double totalWeight = 0.0;
+ double[] distWeights = new double[postSplitDists.length];
+ for (int i = 0; i < postSplitDists.length; i++) {
+ distWeights[i] = Utils.sum(postSplitDists[i]);
+ totalWeight += distWeights[i];
}
-
- @Override
- public double getRangeOfMerit(double[] preSplitDist) {
- return 1.0;
+ double gini = 0.0;
+ for (int i = 0; i < postSplitDists.length; i++) {
+ gini += (distWeights[i] / totalWeight)
+ * computeGini(postSplitDists[i], distWeights[i]);
}
+ return 1.0 - gini;
+ }
- public static double computeGini(double[] dist, double distSumOfWeights) {
- double gini = 1.0;
- for (double aDist : dist) {
- double relFreq = aDist / distSumOfWeights;
- gini -= relFreq * relFreq;
- }
- return gini;
- }
+ @Override
+ public double getRangeOfMerit(double[] preSplitDist) {
+ return 1.0;
+ }
- public static double computeGini(double[] dist) {
- return computeGini(dist, Utils.sum(dist));
+ public static double computeGini(double[] dist, double distSumOfWeights) {
+ double gini = 1.0;
+ for (double aDist : dist) {
+ double relFreq = aDist / distSumOfWeights;
+ gini -= relFreq * relFreq;
}
+ return gini;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ public static double computeGini(double[] dist) {
+ return computeGini(dist, Utils.sum(dist));
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
+
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java
index f136996..bcf3c31 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java
@@ -27,92 +27,91 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Class for computing splitting criteria using information gain
- * with respect to distributions of class values.
- * The split criterion is used as a parameter on
+ * Class for computing splitting criteria using information gain with respect to
+ * distributions of class values. The split criterion is used as a parameter on
* decision trees and decision stumps.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class InfoGainSplitCriterion extends AbstractOptionHandler implements
- SplitCriterion {
+ SplitCriterion {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public FloatOption minBranchFracOption = new FloatOption("minBranchFrac",
- 'f',
- "Minimum fraction of weight required down at least two branches.",
- 0.01, 0.0, 0.5);
+ public FloatOption minBranchFracOption = new FloatOption("minBranchFrac",
+ 'f',
+ "Minimum fraction of weight required down at least two branches.",
+ 0.01, 0.0, 0.5);
- @Override
- public double getMeritOfSplit(double[] preSplitDist,
- double[][] postSplitDists) {
- if (numSubsetsGreaterThanFrac(postSplitDists, this.minBranchFracOption.getValue()) < 2) {
- return Double.NEGATIVE_INFINITY;
- }
- return computeEntropy(preSplitDist) - computeEntropy(postSplitDists);
+ @Override
+ public double getMeritOfSplit(double[] preSplitDist,
+ double[][] postSplitDists) {
+ if (numSubsetsGreaterThanFrac(postSplitDists, this.minBranchFracOption.getValue()) < 2) {
+ return Double.NEGATIVE_INFINITY;
}
+ return computeEntropy(preSplitDist) - computeEntropy(postSplitDists);
+ }
- @Override
- public double getRangeOfMerit(double[] preSplitDist) {
- int numClasses = preSplitDist.length > 2 ? preSplitDist.length : 2;
- return Utils.log2(numClasses);
- }
+ @Override
+ public double getRangeOfMerit(double[] preSplitDist) {
+ int numClasses = preSplitDist.length > 2 ? preSplitDist.length : 2;
+ return Utils.log2(numClasses);
+ }
- public static double computeEntropy(double[] dist) {
- double entropy = 0.0;
- double sum = 0.0;
- for (double d : dist) {
- if (d > 0.0) { // TODO: how small can d be before log2 overflows?
- entropy -= d * Utils.log2(d);
- sum += d;
- }
- }
- return sum > 0.0 ? (entropy + sum * Utils.log2(sum)) / sum : 0.0;
+ public static double computeEntropy(double[] dist) {
+ double entropy = 0.0;
+ double sum = 0.0;
+ for (double d : dist) {
+ if (d > 0.0) { // TODO: how small can d be before log2 overflows?
+ entropy -= d * Utils.log2(d);
+ sum += d;
+ }
}
+ return sum > 0.0 ? (entropy + sum * Utils.log2(sum)) / sum : 0.0;
+ }
- public static double computeEntropy(double[][] dists) {
- double totalWeight = 0.0;
- double[] distWeights = new double[dists.length];
- for (int i = 0; i < dists.length; i++) {
- distWeights[i] = Utils.sum(dists[i]);
- totalWeight += distWeights[i];
- }
- double entropy = 0.0;
- for (int i = 0; i < dists.length; i++) {
- entropy += distWeights[i] * computeEntropy(dists[i]);
- }
- return entropy / totalWeight;
+ public static double computeEntropy(double[][] dists) {
+ double totalWeight = 0.0;
+ double[] distWeights = new double[dists.length];
+ for (int i = 0; i < dists.length; i++) {
+ distWeights[i] = Utils.sum(dists[i]);
+ totalWeight += distWeights[i];
}
+ double entropy = 0.0;
+ for (int i = 0; i < dists.length; i++) {
+ entropy += distWeights[i] * computeEntropy(dists[i]);
+ }
+ return entropy / totalWeight;
+ }
- public static int numSubsetsGreaterThanFrac(double[][] distributions, double minFrac) {
- double totalWeight = 0.0;
- double[] distSums = new double[distributions.length];
- for (int i = 0; i < distSums.length; i++) {
- for (int j = 0; j < distributions[i].length; j++) {
- distSums[i] += distributions[i][j];
- }
- totalWeight += distSums[i];
- }
- int numGreater = 0;
- for (double d : distSums) {
- double frac = d / totalWeight;
- if (frac > minFrac) {
- numGreater++;
- }
- }
- return numGreater;
+ public static int numSubsetsGreaterThanFrac(double[][] distributions, double minFrac) {
+ double totalWeight = 0.0;
+ double[] distSums = new double[distributions.length];
+ for (int i = 0; i < distSums.length; i++) {
+ for (int j = 0; j < distributions[i].length; j++) {
+ distSums[i] += distributions[i][j];
+ }
+ totalWeight += distSums[i];
}
+ int numGreater = 0;
+ for (double d : distSums) {
+ double frac = d / totalWeight;
+ if (frac > minFrac) {
+ numGreater++;
+ }
+ }
+ return numGreater;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java
index 60e1e1c..c06754a 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java
@@ -26,29 +26,31 @@
* Class for computing splitting criteria using information gain with respect to
* distributions of class values for Multilabel data. The split criterion is
* used as a parameter on decision trees and decision stumps.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @author Jesse Read (jesse@tsc.uc3m.es)
* @version $Revision: 1 $
*/
public class InfoGainSplitCriterionMultilabel extends InfoGainSplitCriterion {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public static double computeEntropy(double[] dist) {
- double entropy = 0.0;
- double sum = 0.0;
- for (double d : dist) {
- sum += d;
- }
- if (sum > 0.0) {
- for (double num : dist) {
- double d = num / sum;
- if (d > 0.0) { // TODO: how small can d be before log2 overflows?
- entropy -= d * Utils.log2(d) + (1 - d) * Utils.log2(1 - d); //Extension to Multilabel
- }
- }
- }
- return sum > 0.0 ? entropy : 0.0;
+ public static double computeEntropy(double[] dist) {
+ double entropy = 0.0;
+ double sum = 0.0;
+ for (double d : dist) {
+ sum += d;
}
+ if (sum > 0.0) {
+ for (double num : dist) {
+ double d = num / sum;
+ if (d > 0.0) { // TODO: how small can d be before log2 overflows?
+ entropy -= d * Utils.log2(d) + (1 - d) * Utils.log2(1 - d); // Extension
+ // to
+ // Multilabel
+ }
+ }
+ }
+ return sum > 0.0 ? entropy : 0.0;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java
index a23c93b..d7173d7 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java
@@ -21,13 +21,13 @@
*/
public class SDRSplitCriterion extends VarianceReductionSplitCriterion {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public static double computeSD(double[] dist) {
- int N = (int)dist[0];
- double sum = dist[1];
- double sumSq = dist[2];
- return Math.sqrt((sumSq - ((sum * sum)/N))/N);
- }
+ public static double computeSD(double[] dist) {
+ int N = (int) dist[0];
+ double sum = dist[1];
+ double sumSq = dist[2];
+ return Math.sqrt((sumSq - ((sum * sum) / N)) / N);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java
index eba390e..ce5d661 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java
@@ -23,34 +23,35 @@
import com.yahoo.labs.samoa.moa.options.OptionHandler;
/**
- * Interface for computing splitting criteria.
- * with respect to distributions of class values.
- * The split criterion is used as a parameter on
- * decision trees and decision stumps.
- * The two split criteria most used are
- * Information Gain and Gini.
- *
+ * Interface for computing splitting criteria. with respect to distributions of
+ * class values. The split criterion is used as a parameter on decision trees
+ * and decision stumps. The two split criteria most used are Information Gain
+ * and Gini.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface SplitCriterion extends OptionHandler {
- /**
- * Computes the merit of splitting for a given
- * ditribution before the split and after it.
- *
- * @param preSplitDist the class distribution before the split
- * @param postSplitDists the class distribution after the split
- * @return value of the merit of splitting
- */
- public double getMeritOfSplit(double[] preSplitDist,
- double[][] postSplitDists);
+ /**
+ * Computes the merit of splitting for a given ditribution before the split
+ * and after it.
+ *
+ * @param preSplitDist
+ * the class distribution before the split
+ * @param postSplitDists
+ * the class distribution after the split
+ * @return value of the merit of splitting
+ */
+ public double getMeritOfSplit(double[] preSplitDist,
+ double[][] postSplitDists);
- /**
- * Computes the range of splitting merit
- *
- * @param preSplitDist the class distribution before the split
- * @return value of the range of splitting merit
- */
- public double getRangeOfMerit(double[] preSplitDist);
+ /**
+ * Computes the range of splitting merit
+ *
+ * @param preSplitDist
+ * the class distribution before the split
+ * @return value of the range of splitting merit
+ */
+ public double getRangeOfMerit(double[] preSplitDist);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java
index c5ca348..6aa66ba 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java
@@ -26,74 +26,69 @@
public class VarianceReductionSplitCriterion extends AbstractOptionHandler implements SplitCriterion {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
-/* @Override
- public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) {
-
- double N = preSplitDist[0];
- double SDR = computeSD(preSplitDist);
+ /*
+ * @Override public double getMeritOfSplit(double[] preSplitDist, double[][]
+ * postSplitDists) {
+ *
+ * double N = preSplitDist[0]; double SDR = computeSD(preSplitDist);
+ *
+ * // System.out.print("postSplitDists.length"+postSplitDists.length+"\n");
+ * for(int i = 0; i < postSplitDists.length; i++) { double Ni =
+ * postSplitDists[i][0]; SDR -= (Ni/N)*computeSD(postSplitDists[i]); }
+ *
+ * return SDR; }
+ */
- // System.out.print("postSplitDists.length"+postSplitDists.length+"\n");
- for(int i = 0; i < postSplitDists.length; i++)
- {
- double Ni = postSplitDists[i][0];
- SDR -= (Ni/N)*computeSD(postSplitDists[i]);
- }
+ @Override
+ public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) {
+ double SDR = 0.0;
+ double N = preSplitDist[0];
+ int count = 0;
- return SDR;
- }*/
-
- @Override
- public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) {
- double SDR=0.0;
- double N = preSplitDist[0];
- int count = 0;
-
- for (int i1 = 0; i1 < postSplitDists.length; i1++) {
- double[] postSplitDist = postSplitDists[i1];
- double Ni = postSplitDist[0];
- if (Ni >= 5.0) {
- count = count + 1;
- }
- }
-
- if(count == postSplitDists.length){
- SDR = computeSD(preSplitDist);
- for(int i = 0; i < postSplitDists.length; i++)
- {
- double Ni = postSplitDists[i][0];
- SDR -= (Ni/N)*computeSD(postSplitDists[i]);
- }
- }
- return SDR;
- }
-
-
-
- @Override
- public double getRangeOfMerit(double[] preSplitDist) {
- return 1;
+ for (int i1 = 0; i1 < postSplitDists.length; i1++) {
+ double[] postSplitDist = postSplitDists[i1];
+ double Ni = postSplitDist[0];
+ if (Ni >= 5.0) {
+ count = count + 1;
+ }
}
- public static double computeSD(double[] dist) {
-
- int N = (int)dist[0];
- double sum = dist[1];
- double sumSq = dist[2];
- // return Math.sqrt((sumSq - ((sum * sum)/N))/N);
- return (sumSq - ((sum * sum)/N))/N;
+ if (count == postSplitDists.length) {
+ SDR = computeSD(preSplitDist);
+ for (int i = 0; i < postSplitDists.length; i++)
+ {
+ double Ni = postSplitDists[i][0];
+ SDR -= (Ni / N) * computeSD(postSplitDists[i]);
+ }
}
+ return SDR;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public double getRangeOfMerit(double[] preSplitDist) {
+ return 1;
+ }
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- // TODO Auto-generated method stub
- }
-
+ public static double computeSD(double[] dist) {
+
+ int N = (int) dist[0];
+ double sum = dist[1];
+ double sumSq = dist[2];
+ // return Math.sqrt((sumSq - ((sum * sum)/N))/N);
+ return (sumSq - ((sum * sum) / N)) / N;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
+
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // TODO Auto-generated method stub
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java
index 6c2c807..ca3e452 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java
@@ -28,57 +28,57 @@
/**
* Majority class learner. This is the simplest classifier.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class MajorityClass extends AbstractClassifier {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- @Override
- public String getPurposeString() {
- return "Majority class classifier: always predicts the class that has been observed most frequently the in the training data.";
+ @Override
+ public String getPurposeString() {
+ return "Majority class classifier: always predicts the class that has been observed most frequently the in the training data.";
+ }
+
+ protected DoubleVector observedClassDistribution;
+
+ @Override
+ public void resetLearningImpl() {
+ this.observedClassDistribution = new DoubleVector();
+ }
+
+ @Override
+ public void trainOnInstanceImpl(Instance inst) {
+ this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight());
+ }
+
+ public double[] getVotesForInstance(Instance i) {
+ return this.observedClassDistribution.getArrayCopy();
+ }
+
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ return null;
+ }
+
+ @Override
+ public void getModelDescription(StringBuilder out, int indent) {
+ StringUtils.appendIndented(out, indent, "Predicted majority ");
+ out.append(getClassNameString());
+ out.append(" = ");
+ out.append(getClassLabelString(this.observedClassDistribution.maxIndex()));
+ StringUtils.appendNewline(out);
+ for (int i = 0; i < this.observedClassDistribution.numValues(); i++) {
+ StringUtils.appendIndented(out, indent, "Observed weight of ");
+ out.append(getClassLabelString(i));
+ out.append(": ");
+ out.append(this.observedClassDistribution.getValue(i));
+ StringUtils.appendNewline(out);
}
+ }
- protected DoubleVector observedClassDistribution;
-
- @Override
- public void resetLearningImpl() {
- this.observedClassDistribution = new DoubleVector();
- }
-
- @Override
- public void trainOnInstanceImpl(Instance inst) {
- this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight());
- }
-
- public double[] getVotesForInstance(Instance i) {
- return this.observedClassDistribution.getArrayCopy();
- }
-
- @Override
- protected Measurement[] getModelMeasurementsImpl() {
- return null;
- }
-
- @Override
- public void getModelDescription(StringBuilder out, int indent) {
- StringUtils.appendIndented(out, indent, "Predicted majority ");
- out.append(getClassNameString());
- out.append(" = ");
- out.append(getClassLabelString(this.observedClassDistribution.maxIndex()));
- StringUtils.appendNewline(out);
- for (int i = 0; i < this.observedClassDistribution.numValues(); i++) {
- StringUtils.appendIndented(out, indent, "Observed weight of ");
- out.append(getClassLabelString(i));
- out.append(": ");
- out.append(this.observedClassDistribution.getValue(i));
- StringUtils.appendNewline(out);
- }
- }
-
- public boolean isRandomizable() {
- return false;
- }
+ public boolean isRandomizable() {
+ return false;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java
index d6bdcaa..6ebd3c5 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java
@@ -24,10 +24,10 @@
/**
* Interface for a predicate (a feature) in rules.
- *
+ *
*/
public interface Predicate {
-
- public boolean evaluate(Instance instance);
-
+
+ public boolean evaluate(Instance instance);
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java
index 8b4c332..6633da8 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java
@@ -23,104 +23,99 @@
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.FIMTDDNumericAttributeClassObserver;
-
public class FIMTDDNumericAttributeClassLimitObserver extends FIMTDDNumericAttributeClassObserver {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 1L;
- protected int maxNodes;
- //public IntOption maxNodesOption = new IntOption("maxNodesOption", 'z', "Maximum number of nodes", 50, 0, Integer.MAX_VALUE);
+ private static final long serialVersionUID = 1L;
+ protected int maxNodes;
+ // public IntOption maxNodesOption = new IntOption("maxNodesOption", 'z',
+ // "Maximum number of nodes", 50, 0, Integer.MAX_VALUE);
+ protected int numNodes;
- protected int numNodes;
-
- public int getMaxNodes() {
- return this.maxNodes;
- }
-
- public void setMaxNodes(int maxNodes) {
- this.maxNodes = maxNodes;
- }
+ public int getMaxNodes() {
+ return this.maxNodes;
+ }
- @Override
- public void observeAttributeClass(double attVal, double classVal, double weight) {
- if (Double.isNaN(attVal)) { //Instance.isMissingValue(attVal)
- } else {
- if (this.root == null) {
- //maxNodes=maxNodesOption.getValue();
- maxNodes = 50;
- this.root = new FIMTDDNumericAttributeClassLimitObserver.Node(attVal, classVal, weight);
- } else {
- this.root.insertValue(attVal, classVal, weight);
- }
- }
- }
-
- protected class Node extends FIMTDDNumericAttributeClassObserver.Node {
- /**
+ public void setMaxNodes(int maxNodes) {
+ this.maxNodes = maxNodes;
+ }
+
+ @Override
+ public void observeAttributeClass(double attVal, double classVal, double weight) {
+ if (Double.isNaN(attVal)) { // Instance.isMissingValue(attVal)
+ } else {
+ if (this.root == null) {
+ // maxNodes=maxNodesOption.getValue();
+ maxNodes = 50;
+ this.root = new FIMTDDNumericAttributeClassLimitObserver.Node(attVal, classVal, weight);
+ } else {
+ this.root.insertValue(attVal, classVal, weight);
+ }
+ }
+ }
+
+ protected class Node extends FIMTDDNumericAttributeClassObserver.Node {
+ /**
*
*/
- private static final long serialVersionUID = -4484141636424708465L;
+ private static final long serialVersionUID = -4484141636424708465L;
+ public Node(double val, double label, double weight) {
+ super(val, label, weight);
+ }
+ protected Node root = null;
- public Node(double val, double label, double weight) {
- super(val, label, weight);
- }
+ /**
+ * Insert a new value into the tree, updating both the sum of values and sum
+ * of squared values arrays
+ */
+ @Override
+ public void insertValue(double val, double label, double weight) {
- protected Node root = null;
-
-
-
- /**
- * Insert a new value into the tree, updating both the sum of values and
- * sum of squared values arrays
- */
- @Override
- public void insertValue(double val, double label, double weight) {
-
- // If the new value equals the value stored in a node, update
- // the left (<=) node information
- if (val == this.cut_point)
- {
- this.leftStatistics.addToValue(0,1);
- this.leftStatistics.addToValue(1,label);
- this.leftStatistics.addToValue(2,label*label);
- }
- // If the new value is less than the value in a node, update the
- // left distribution and send the value down to the left child node.
- // If no left child exists, create one
- else if (val <= this.cut_point) {
- this.leftStatistics.addToValue(0,1);
- this.leftStatistics.addToValue(1,label);
- this.leftStatistics.addToValue(2,label*label);
- if (this.left == null) {
- if(numNodes<maxNodes){
- this.left = new Node(val, label, weight);
- ++numNodes;
- }
- } else {
- this.left.insertValue(val, label, weight);
- }
- }
- // If the new value is greater than the value in a node, update the
- // right (>) distribution and send the value down to the right child node.
- // If no right child exists, create one
- else { // val > cut_point
- this.rightStatistics.addToValue(0,1);
- this.rightStatistics.addToValue(1,label);
- this.rightStatistics.addToValue(2,label*label);
- if (this.right == null) {
- if(numNodes<maxNodes){
- this.right = new Node(val, label, weight);
- ++numNodes;
- }
- } else {
- this.right.insertValue(val, label, weight);
- }
- }
- }
- }
+ // If the new value equals the value stored in a node, update
+ // the left (<=) node information
+ if (val == this.cut_point)
+ {
+ this.leftStatistics.addToValue(0, 1);
+ this.leftStatistics.addToValue(1, label);
+ this.leftStatistics.addToValue(2, label * label);
+ }
+ // If the new value is less than the value in a node, update the
+ // left distribution and send the value down to the left child node.
+ // If no left child exists, create one
+ else if (val <= this.cut_point) {
+ this.leftStatistics.addToValue(0, 1);
+ this.leftStatistics.addToValue(1, label);
+ this.leftStatistics.addToValue(2, label * label);
+ if (this.left == null) {
+ if (numNodes < maxNodes) {
+ this.left = new Node(val, label, weight);
+ ++numNodes;
+ }
+ } else {
+ this.left.insertValue(val, label, weight);
+ }
+ }
+ // If the new value is greater than the value in a node, update the
+ // right (>) distribution and send the value down to the right child node.
+ // If no right child exists, create one
+ else { // val > cut_point
+ this.rightStatistics.addToValue(0, 1);
+ this.rightStatistics.addToValue(1, label);
+ this.rightStatistics.addToValue(2, label * label);
+ if (this.right == null) {
+ if (numNodes < maxNodes) {
+ this.right = new Node(val, label, weight);
+ ++numNodes;
+ }
+ } else {
+ this.right.insertValue(val, label, weight);
+ }
+ }
+ }
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java
index 5f27c40..94edc33 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java
@@ -47,134 +47,135 @@
/**
* Numeric binary conditional test for instances to use to split nodes in
* AMRules.
- *
+ *
* @version $Revision: 1 $
*/
public class NumericAttributeBinaryRulePredicate extends InstanceConditionalBinaryTest implements Predicate {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected int attIndex;
+ protected int attIndex;
- protected double attValue;
+ protected double attValue;
- protected int operator; // 0 =, 1<=, 2>
+ protected int operator; // 0 =, 1<=, 2>
- public NumericAttributeBinaryRulePredicate() {
- this(0,0,0);
+ public NumericAttributeBinaryRulePredicate() {
+ this(0, 0, 0);
+ }
+
+ public NumericAttributeBinaryRulePredicate(int attIndex, double attValue,
+ int operator) {
+ this.attIndex = attIndex;
+ this.attValue = attValue;
+ this.operator = operator;
+ }
+
+ public NumericAttributeBinaryRulePredicate(NumericAttributeBinaryRulePredicate oldTest) {
+ this(oldTest.attIndex, oldTest.attValue, oldTest.operator);
+ }
+
+ @Override
+ public int branchForInstance(Instance inst) {
+ int instAttIndex = this.attIndex < inst.classIndex() ? this.attIndex
+ : this.attIndex + 1;
+ if (inst.isMissing(instAttIndex)) {
+ return -1;
}
- public NumericAttributeBinaryRulePredicate(int attIndex, double attValue,
- int operator) {
- this.attIndex = attIndex;
- this.attValue = attValue;
- this.operator = operator;
+ double v = inst.value(instAttIndex);
+ int ret = 0;
+ switch (this.operator) {
+ case 0:
+ ret = (v == this.attValue) ? 0 : 1;
+ break;
+ case 1:
+ ret = (v <= this.attValue) ? 0 : 1;
+ break;
+ case 2:
+ ret = (v > this.attValue) ? 0 : 1;
}
-
- public NumericAttributeBinaryRulePredicate(NumericAttributeBinaryRulePredicate oldTest) {
- this(oldTest.attIndex, oldTest.attValue, oldTest.operator);
- }
+ return ret;
+ }
- @Override
- public int branchForInstance(Instance inst) {
- int instAttIndex = this.attIndex < inst.classIndex() ? this.attIndex
- : this.attIndex + 1;
- if (inst.isMissing(instAttIndex)) {
- return -1;
- }
- double v = inst.value(instAttIndex);
- int ret = 0;
- switch (this.operator) {
- case 0:
- ret = (v == this.attValue) ? 0 : 1;
- break;
- case 1:
- ret = (v <= this.attValue) ? 0 : 1;
- break;
- case 2:
- ret = (v > this.attValue) ? 0 : 1;
- }
- return ret;
- }
-
- /**
+ /**
*
*/
- @Override
- public String describeConditionForBranch(int branch, InstancesHeader context) {
- if ((branch >= 0) && (branch <= 2)) {
- String compareChar = (branch == 0) ? "=" : (branch == 1) ? "<=" : ">";
- return InstancesHeader.getAttributeNameString(context,
- this.attIndex)
- + ' '
- + compareChar
- + InstancesHeader.getNumericValueString(context,
- this.attIndex, this.attValue);
- }
- throw new IndexOutOfBoundsException();
+ @Override
+ public String describeConditionForBranch(int branch, InstancesHeader context) {
+ if ((branch >= 0) && (branch <= 2)) {
+ String compareChar = (branch == 0) ? "=" : (branch == 1) ? "<=" : ">";
+ return InstancesHeader.getAttributeNameString(context,
+ this.attIndex)
+ + ' '
+ + compareChar
+ + InstancesHeader.getNumericValueString(context,
+ this.attIndex, this.attValue);
}
+ throw new IndexOutOfBoundsException();
+ }
- /**
+ /**
*
*/
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
+
+ @Override
+ public int[] getAttsTestDependsOn() {
+ return new int[] { this.attIndex };
+ }
+
+ public double getSplitValue() {
+ return this.attValue;
+ }
+
+ @Override
+ public boolean evaluate(Instance inst) {
+ return (branchForInstance(inst) == 0);
+ }
+
+ @Override
+ public String toString() {
+ if ((operator >= 0) && (operator <= 2)) {
+ String compareChar = (operator == 0) ? "=" : (operator == 1) ? "<=" : ">";
+ // int equalsBranch = this.equalsPassesTest ? 0 : 1;
+ return "x" + this.attIndex
+ + ' '
+ + compareChar
+ + ' '
+ + this.attValue;
+ }
+ throw new IndexOutOfBoundsException();
+ }
+
+ public boolean isEqual(NumericAttributeBinaryRulePredicate predicate) {
+ return (this.attIndex == predicate.attIndex
+ && this.attValue == predicate.attValue
+ && this.operator == predicate.operator);
+ }
+
+ public boolean isUsingSameAttribute(NumericAttributeBinaryRulePredicate predicate) {
+ return (this.attIndex == predicate.attIndex
+ && this.operator == predicate.operator);
+ }
+
+ public boolean isIncludedInRuleNode(
+ NumericAttributeBinaryRulePredicate predicate) {
+ boolean ret;
+ if (this.operator == 1) { // <=
+ ret = (predicate.attValue <= this.attValue);
+ } else { // >
+ ret = (predicate.attValue > this.attValue);
}
- @Override
- public int[] getAttsTestDependsOn() {
- return new int[]{this.attIndex};
- }
+ return ret;
+ }
- public double getSplitValue() {
- return this.attValue;
- }
+ public void setAttributeValue(
+ NumericAttributeBinaryRulePredicate ruleSplitNodeTest) {
+ this.attValue = ruleSplitNodeTest.attValue;
- @Override
- public boolean evaluate(Instance inst) {
- return (branchForInstance(inst) == 0);
- }
-
- @Override
- public String toString() {
- if ((operator >= 0) && (operator <= 2)) {
- String compareChar = (operator == 0) ? "=" : (operator == 1) ? "<=" : ">";
- //int equalsBranch = this.equalsPassesTest ? 0 : 1;
- return "x" + this.attIndex
- + ' '
- + compareChar
- + ' '
- + this.attValue;
- }
- throw new IndexOutOfBoundsException();
- }
-
- public boolean isEqual(NumericAttributeBinaryRulePredicate predicate) {
- return (this.attIndex == predicate.attIndex
- && this.attValue == predicate.attValue
- && this.operator == predicate.operator);
- }
-
- public boolean isUsingSameAttribute(NumericAttributeBinaryRulePredicate predicate) {
- return (this.attIndex == predicate.attIndex
- && this.operator == predicate.operator);
- }
-
- public boolean isIncludedInRuleNode(
- NumericAttributeBinaryRulePredicate predicate) {
- boolean ret;
- if (this.operator == 1) { // <=
- ret = (predicate.attValue <= this.attValue);
- } else { // >
- ret = (predicate.attValue > this.attValue);
- }
-
- return ret;
- }
-
- public void setAttributeValue(
- NumericAttributeBinaryRulePredicate ruleSplitNodeTest) {
- this.attValue = ruleSplitNodeTest.attValue;
-
- }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java
index 58d1e2f..818d075 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java
@@ -43,63 +43,57 @@
import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SDRSplitCriterion;
import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
-
public class SDRSplitCriterionAMRules extends SDRSplitCriterion implements SplitCriterion {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
+ @Override
+ public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) {
+ double SDR = 0.0;
+ double N = preSplitDist[0];
+ int count = 0;
- @Override
- public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) {
- double SDR=0.0;
- double N = preSplitDist[0];
- int count = 0;
+ for (int i = 0; i < postSplitDists.length; i++)
+ {
+ double Ni = postSplitDists[i][0];
+ if (Ni >= 0.05 * preSplitDist[0]) {
+ count = count + 1;
+ }
+ }
+ if (count == postSplitDists.length) {
+ SDR = computeSD(preSplitDist);
- for(int i = 0; i < postSplitDists.length; i++)
- {
- double Ni = postSplitDists[i][0];
- if(Ni >=0.05*preSplitDist[0]){
- count = count +1;
- }
- }
- if(count == postSplitDists.length){
- SDR = computeSD(preSplitDist);
+ for (int i = 0; i < postSplitDists.length; i++)
+ {
+ double Ni = postSplitDists[i][0];
+ SDR -= (Ni / N) * computeSD(postSplitDists[i]);
- for(int i = 0; i < postSplitDists.length; i++)
- {
- double Ni = postSplitDists[i][0];
- SDR -= (Ni/N)*computeSD(postSplitDists[i]);
+ }
+ }
+ return SDR;
+ }
- }
- }
- return SDR;
- }
+ @Override
+ public double getRangeOfMerit(double[] preSplitDist) {
+ return 1;
+ }
+ public static double[] computeBranchSplitMerits(double[][] postSplitDists) {
+ double[] SDR = new double[postSplitDists.length];
+ double N = 0;
+ for (int i = 0; i < postSplitDists.length; i++)
+ {
+ double Ni = postSplitDists[i][0];
+ N += Ni;
+ }
+ for (int i = 0; i < postSplitDists.length; i++)
+ {
+ double Ni = postSplitDists[i][0];
+ SDR[i] = (Ni / N) * computeSD(postSplitDists[i]);
+ }
+ return SDR;
- @Override
- public double getRangeOfMerit(double[] preSplitDist) {
- return 1;
- }
-
- public static double[] computeBranchSplitMerits(double[][] postSplitDists) {
- double[] SDR = new double[postSplitDists.length];
- double N = 0;
-
- for(int i = 0; i < postSplitDists.length; i++)
- {
- double Ni = postSplitDists[i][0];
- N += Ni;
- }
- for(int i = 0; i < postSplitDists.length; i++)
- {
- double Ni = postSplitDists[i][0];
- SDR[i] = (Ni/N)*computeSD(postSplitDists[i]);
- }
- return SDR;
-
- }
-
+ }
}
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java
index 4e93b03..dcd975d 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java
@@ -26,78 +26,77 @@
import com.yahoo.labs.samoa.moa.AbstractMOAObject;
/**
- * AbstractErrorWeightedVote class for weighted votes based on estimates of errors.
- *
+ * AbstractErrorWeightedVote class for weighted votes based on estimates of
+ * errors.
+ *
* @author Joao Duarte (jmduarte@inescporto.pt)
* @version $Revision: 1 $
*/
-public abstract class AbstractErrorWeightedVote extends AbstractMOAObject implements ErrorWeightedVote{
- /**
+public abstract class AbstractErrorWeightedVote extends AbstractMOAObject implements ErrorWeightedVote {
+ /**
*
*/
- private static final long serialVersionUID = -7340491298217227675L;
- protected List<double[]> votes;
- protected List<Double> errors;
- protected double[] weights;
+ private static final long serialVersionUID = -7340491298217227675L;
+ protected List<double[]> votes;
+ protected List<Double> errors;
+ protected double[] weights;
+ public AbstractErrorWeightedVote() {
+ super();
+ votes = new ArrayList<double[]>();
+ errors = new ArrayList<Double>();
+ }
+ public AbstractErrorWeightedVote(AbstractErrorWeightedVote aewv) {
+ super();
+ votes = new ArrayList<double[]>();
+ for (double[] vote : aewv.votes) {
+ double[] v = new double[vote.length];
+ for (int i = 0; i < vote.length; i++)
+ v[i] = vote[i];
+ votes.add(v);
+ }
+ errors = new ArrayList<Double>();
+ for (Double db : aewv.errors) {
+ errors.add(db.doubleValue());
+ }
+ if (aewv.weights != null) {
+ weights = new double[aewv.weights.length];
+ for (int i = 0; i < aewv.weights.length; i++)
+ weights[i] = aewv.weights[i];
+ }
+ }
- public AbstractErrorWeightedVote() {
- super();
- votes = new ArrayList<double[]>();
- errors = new ArrayList<Double>();
- }
-
- public AbstractErrorWeightedVote(AbstractErrorWeightedVote aewv) {
- super();
- votes = new ArrayList<double[]>();
- for (double[] vote:aewv.votes) {
- double[] v = new double[vote.length];
- for (int i=0; i<vote.length; i++) v[i] = vote[i];
- votes.add(v);
- }
- errors = new ArrayList<Double>();
- for (Double db:aewv.errors) {
- errors.add(db.doubleValue());
- }
- if (aewv.weights != null) {
- weights = new double[aewv.weights.length];
- for (int i = 0; i<aewv.weights.length; i++)
- weights[i] = aewv.weights[i];
- }
- }
+ @Override
+ public void addVote(double[] vote, double error) {
+ votes.add(vote);
+ errors.add(error);
+ }
+ @Override
+ abstract public double[] computeWeightedVote();
- @Override
- public void addVote(double [] vote, double error) {
- votes.add(vote);
- errors.add(error);
- }
+ @Override
+ public double getWeightedError()
+ {
+ double weightedError = 0;
+ if (weights != null && weights.length == errors.size())
+ {
+ for (int i = 0; i < weights.length; ++i)
+ weightedError += errors.get(i) * weights[i];
+ }
+ else
+ weightedError = -1;
+ return weightedError;
+ }
- @Override
- abstract public double[] computeWeightedVote();
+ @Override
+ public double[] getWeights() {
+ return weights;
+ }
- @Override
- public double getWeightedError()
- {
- double weightedError=0;
- if (weights!=null && weights.length==errors.size())
- {
- for (int i=0; i<weights.length; ++i)
- weightedError+=errors.get(i)*weights[i];
- }
- else
- weightedError=-1;
- return weightedError;
- }
-
- @Override
- public double [] getWeights() {
- return weights;
- }
-
- @Override
- public int getNumberVotes() {
- return votes.size();
- }
+ @Override
+ public int getNumberVotes() {
+ return votes.size();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java
index 943dd9d..fca7115 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java
@@ -23,59 +23,60 @@
import com.yahoo.labs.samoa.moa.MOAObject;
/**
- * ErrorWeightedVote interface for weighted votes based on estimates of errors.
- *
+ * ErrorWeightedVote interface for weighted votes based on estimates of errors.
+ *
* @author Joao Duarte (jmduarte@inescporto.pt)
* @version $Revision: 1 $
*/
public interface ErrorWeightedVote {
-
- /**
- * Adds a vote and the corresponding error for the computation of the weighted vote and respective weighted error.
- *
- * @param vote a vote returned by a classifier
- * @param error the error associated to the vote
- */
- public void addVote(double [] vote, double error);
-
- /**
- * Computes the weighted vote.
- * Also updates the weights of the votes.
- *
- * @return the weighted vote
- */
- public double [] computeWeightedVote();
-
- /**
- * Returns the weighted error.
- *
- * @pre computeWeightedVote()
- * @return the weighted error
- */
- public double getWeightedError();
-
- /**
- * Return the weights error.
- *
- * @pre computeWeightedVote()
- * @return the weights
- */
- public double [] getWeights();
-
-
- /**
- * The number of votes added so far.
- *
- * @return the number of votes
- */
- public int getNumberVotes();
-
- /**
- * Creates a copy of the object
- *
- * @return copy of the object
- */
- public MOAObject copy();
-
- public ErrorWeightedVote getACopy();
+
+ /**
+ * Adds a vote and the corresponding error for the computation of the weighted
+ * vote and respective weighted error.
+ *
+ * @param vote
+ * a vote returned by a classifier
+ * @param error
+ * the error associated to the vote
+ */
+ public void addVote(double[] vote, double error);
+
+ /**
+ * Computes the weighted vote. Also updates the weights of the votes.
+ *
+ * @return the weighted vote
+ */
+ public double[] computeWeightedVote();
+
+ /**
+ * Returns the weighted error.
+ *
+ * @pre computeWeightedVote()
+ * @return the weighted error
+ */
+ public double getWeightedError();
+
+ /**
+ * Returns the weights.
+ *
+ * @pre computeWeightedVote()
+ * @return the weights
+ */
+ public double[] getWeights();
+
+ /**
+ * The number of votes added so far.
+ *
+ * @return the number of votes
+ */
+ public int getNumberVotes();
+
+ /**
+ * Creates a copy of the object
+ *
+ * @return copy of the object
+ */
+ public MOAObject copy();
+
+ public ErrorWeightedVote getACopy();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java
index 401dc58..2e2cb56 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java
@@ -21,79 +21,81 @@
*/
/**
- * InverseErrorWeightedVote class for weighted votes based on estimates of errors.
- *
+ * InverseErrorWeightedVote class for weighted votes based on estimates of
+ * errors.
+ *
* @author Joao Duarte (jmduarte@inescporto.pt)
* @version $Revision: 1 $
*/
public class InverseErrorWeightedVote extends AbstractErrorWeightedVote {
- /**
+ /**
*
*/
- private static final double EPS = 0.000000001; //just to prevent divide by 0 in 1/X -> 1/(x+EPS)
- private static final long serialVersionUID = 6359349250620616482L;
+ private static final double EPS = 0.000000001; // just to prevent divide by 0
+ // in 1/X -> 1/(x+EPS)
+ private static final long serialVersionUID = 6359349250620616482L;
- public InverseErrorWeightedVote() {
- super();
- }
-
- public InverseErrorWeightedVote(AbstractErrorWeightedVote aewv) {
- super(aewv);
- }
-
- @Override
- public double[] computeWeightedVote() {
- int n=votes.size();
- weights=new double[n];
- double [] weightedVote=null;
- if (n>0){
- int d=votes.get(0).length;
- weightedVote=new double[d];
- double sumError=0;
- //weights are 1/(error+eps)
- for (int i=0; i<n; ++i){
- if(errors.get(i)<Double.MAX_VALUE){
- weights[i]=1.0/(errors.get(i)+EPS);
- sumError+=weights[i];
- }
- else
- weights[i]=0;
+ public InverseErrorWeightedVote() {
+ super();
+ }
- }
+ public InverseErrorWeightedVote(AbstractErrorWeightedVote aewv) {
+ super(aewv);
+ }
- if(sumError>0)
- for (int i=0; i<n; ++i)
- {
- //normalize so that weights sum 1
- weights[i]/=sumError;
- //compute weighted vote
- for(int j=0; j<d; j++)
- weightedVote[j]+=votes.get(i)[j]*weights[i];
- }
- //Only occurs if all errors=Double.MAX_VALUE
- else
- {
- //compute arithmetic vote
- for (int i=0; i<n; ++i)
- {
- for(int j=0; j<d; j++)
- weightedVote[j]+=votes.get(i)[j]/n;
- }
- }
- }
- return weightedVote;
- }
+ @Override
+ public double[] computeWeightedVote() {
+ int n = votes.size();
+ weights = new double[n];
+ double[] weightedVote = null;
+ if (n > 0) {
+ int d = votes.get(0).length;
+ weightedVote = new double[d];
+ double sumError = 0;
+ // weights are 1/(error+eps)
+ for (int i = 0; i < n; ++i) {
+ if (errors.get(i) < Double.MAX_VALUE) {
+ weights[i] = 1.0 / (errors.get(i) + EPS);
+ sumError += weights[i];
+ }
+ else
+ weights[i] = 0;
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
+ }
- }
-
- @Override
- public InverseErrorWeightedVote getACopy() {
- return new InverseErrorWeightedVote(this);
- }
+ if (sumError > 0)
+ for (int i = 0; i < n; ++i)
+ {
+ // normalize so that weights sum 1
+ weights[i] /= sumError;
+ // compute weighted vote
+ for (int j = 0; j < d; j++)
+ weightedVote[j] += votes.get(i)[j] * weights[i];
+ }
+ // Only occurs if all errors=Double.MAX_VALUE
+ else
+ {
+ // compute arithmetic vote
+ for (int i = 0; i < n; ++i)
+ {
+ for (int j = 0; j < d; j++)
+ weightedVote[j] += votes.get(i)[j] / n;
+ }
+ }
+ }
+ return weightedVote;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+
+ }
+
+ @Override
+ public InverseErrorWeightedVote getACopy() {
+ return new InverseErrorWeightedVote(this);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java
index ce7a74f..61c5b12 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java
@@ -20,54 +20,52 @@
* #L%
*/
-
/**
- * UniformWeightedVote class for weighted votes based on estimates of errors.
- *
+ * UniformWeightedVote class for weighted votes based on estimates of errors.
+ *
* @author Joao Duarte (jmduarte@inescporto.pt)
* @version $Revision: 1 $
*/
public class UniformWeightedVote extends AbstractErrorWeightedVote {
+ private static final long serialVersionUID = 6359349250620616482L;
- private static final long serialVersionUID = 6359349250620616482L;
+ public UniformWeightedVote() {
+ super();
+ }
- public UniformWeightedVote() {
- super();
- }
-
- public UniformWeightedVote(AbstractErrorWeightedVote aewv) {
- super(aewv);
- }
-
- @Override
- public double[] computeWeightedVote() {
- int n=votes.size();
- weights=new double[n];
- double [] weightedVote=null;
- if (n>0){
- int d=votes.get(0).length;
- weightedVote=new double[d];
- for (int i=0; i<n; i++)
- {
- weights[i]=1.0/n;
- for(int j=0; j<d; j++)
- weightedVote[j]+=(votes.get(i)[j]*weights[i]);
- }
+ public UniformWeightedVote(AbstractErrorWeightedVote aewv) {
+ super(aewv);
+ }
- }
- return weightedVote;
- }
+ @Override
+ public double[] computeWeightedVote() {
+ int n = votes.size();
+ weights = new double[n];
+ double[] weightedVote = null;
+ if (n > 0) {
+ int d = votes.get(0).length;
+ weightedVote = new double[d];
+ for (int i = 0; i < n; i++)
+ {
+ weights[i] = 1.0 / n;
+ for (int j = 0; j < d; j++)
+ weightedVote[j] += (votes.get(i)[j] * weights[i]);
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
+ }
+ return weightedVote;
+ }
- }
-
- @Override
- public UniformWeightedVote getACopy() {
- return new UniformWeightedVote(this);
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+
+ }
+
+ @Override
+ public UniformWeightedVote getACopy() {
+ return new UniformWeightedVote(this);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java
index 133c755..d18134b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java
@@ -22,62 +22,67 @@
/**
* Page-Hinkley Test with more weight for recent instances.
- *
+ *
*/
public class PageHinkleyFading extends PageHinkleyTest {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 7110953184708812339L;
- private double fadingFactor=0.99;
+ private static final long serialVersionUID = 7110953184708812339L;
+ private double fadingFactor = 0.99;
- public PageHinkleyFading() {
- super();
- }
-
- public PageHinkleyFading(double threshold, double alpha) {
- super(threshold, alpha);
- }
- protected double instancesSeen;
+ public PageHinkleyFading() {
+ super();
+ }
- @Override
- public void reset() {
+ public PageHinkleyFading(double threshold, double alpha) {
+ super(threshold, alpha);
+ }
- super.reset();
- this.instancesSeen=0;
+ protected double instancesSeen;
- }
+ @Override
+ public void reset() {
- @Override
- public boolean update(double error) {
- this.instancesSeen=1+fadingFactor*this.instancesSeen;
- double absolutError = Math.abs(error);
+ super.reset();
+ this.instancesSeen = 0;
- this.sumAbsolutError = fadingFactor*this.sumAbsolutError + absolutError;
- if (this.instancesSeen > 30) {
- double mT = absolutError - (this.sumAbsolutError / this.instancesSeen) - this.alpha;
- this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT sum
- if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT value if the new mT is smaller than the current minimum
- this.minimumValue = this.cumulativeSum;
- }
- return (((this.cumulativeSum - this.minimumValue) > this.threshold));
- }
- return false;
- }
-
- @Override
- public PageHinkleyTest getACopy() {
- PageHinkleyFading newTest = new PageHinkleyFading(this.threshold, this.alpha);
- this.copyFields(newTest);
- return newTest;
- }
-
- @Override
- protected void copyFields(PageHinkleyTest newTest) {
- super.copyFields(newTest);
- PageHinkleyFading newFading = (PageHinkleyFading) newTest;
- newFading.fadingFactor = this.fadingFactor;
- newFading.instancesSeen = this.instancesSeen;
- }
+ }
+
+ @Override
+ public boolean update(double error) {
+ this.instancesSeen = 1 + fadingFactor * this.instancesSeen;
+ double absolutError = Math.abs(error);
+
+ this.sumAbsolutError = fadingFactor * this.sumAbsolutError + absolutError;
+ if (this.instancesSeen > 30) {
+ double mT = absolutError - (this.sumAbsolutError / this.instancesSeen) - this.alpha;
+ this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT
+ // sum
+ if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT
+ // value if the new mT is
+ // smaller than the current
+ // minimum
+ this.minimumValue = this.cumulativeSum;
+ }
+ return (((this.cumulativeSum - this.minimumValue) > this.threshold));
+ }
+ return false;
+ }
+
+ @Override
+ public PageHinkleyTest getACopy() {
+ PageHinkleyFading newTest = new PageHinkleyFading(this.threshold, this.alpha);
+ this.copyFields(newTest);
+ return newTest;
+ }
+
+ @Override
+ protected void copyFields(PageHinkleyTest newTest) {
+ super.copyFields(newTest);
+ PageHinkleyFading newFading = (PageHinkleyFading) newTest;
+ newFading.fadingFactor = this.fadingFactor;
+ newFading.instancesSeen = this.instancesSeen;
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java
index c313224..354b9a8 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java
@@ -24,73 +24,76 @@
/**
* Page-Hinkley Test with equal weights for all instances.
- *
+ *
*/
public class PageHinkleyTest implements Serializable {
- private static final long serialVersionUID = 1L;
- protected double cumulativeSum;
+ private static final long serialVersionUID = 1L;
+ protected double cumulativeSum;
- public double getCumulativeSum() {
- return cumulativeSum;
- }
+ public double getCumulativeSum() {
+ return cumulativeSum;
+ }
- public double getMinimumValue() {
- return minimumValue;
- }
+ public double getMinimumValue() {
+ return minimumValue;
+ }
+ protected double minimumValue;
+ protected double sumAbsolutError;
+ protected long phinstancesSeen;
+ protected double threshold;
+ protected double alpha;
- protected double minimumValue;
- protected double sumAbsolutError;
- protected long phinstancesSeen;
- protected double threshold;
- protected double alpha;
+ public PageHinkleyTest() {
+ this(0, 0);
+ }
- public PageHinkleyTest() {
- this(0,0);
- }
-
- public PageHinkleyTest(double threshold, double alpha) {
- this.threshold = threshold;
- this.alpha = alpha;
- this.reset();
- }
+ public PageHinkleyTest(double threshold, double alpha) {
+ this.threshold = threshold;
+ this.alpha = alpha;
+ this.reset();
+ }
- public void reset() {
- this.cumulativeSum = 0.0;
- this.minimumValue = Double.MAX_VALUE;
- this.sumAbsolutError = 0.0;
- this.phinstancesSeen = 0;
- }
+ public void reset() {
+ this.cumulativeSum = 0.0;
+ this.minimumValue = Double.MAX_VALUE;
+ this.sumAbsolutError = 0.0;
+ this.phinstancesSeen = 0;
+ }
- //Compute Page-Hinkley test
- public boolean update(double error) {
+ // Compute Page-Hinkley test
+ public boolean update(double error) {
- this.phinstancesSeen++;
- double absolutError = Math.abs(error);
- this.sumAbsolutError = this.sumAbsolutError + absolutError;
- if (this.phinstancesSeen > 30) {
- double mT = absolutError - (this.sumAbsolutError / this.phinstancesSeen) - this.alpha;
- this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT sum
- if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT value if the new mT is smaller than the current minimum
- this.minimumValue = this.cumulativeSum;
- }
- return (((this.cumulativeSum - this.minimumValue) > this.threshold));
- }
- return false;
- }
-
- public PageHinkleyTest getACopy() {
- PageHinkleyTest newTest = new PageHinkleyTest(this.threshold, this.alpha);
- this.copyFields(newTest);
- return newTest;
- }
-
- protected void copyFields(PageHinkleyTest newTest) {
- newTest.cumulativeSum = this.cumulativeSum;
- newTest.minimumValue = this.minimumValue;
- newTest.sumAbsolutError = this.sumAbsolutError;
- newTest.phinstancesSeen = this.phinstancesSeen;
- }
+ this.phinstancesSeen++;
+ double absolutError = Math.abs(error);
+ this.sumAbsolutError = this.sumAbsolutError + absolutError;
+ if (this.phinstancesSeen > 30) {
+ double mT = absolutError - (this.sumAbsolutError / this.phinstancesSeen) - this.alpha;
+ this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT
+ // sum
+ if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT
+ // value if the new mT is
+ // smaller than the current
+ // minimum
+ this.minimumValue = this.cumulativeSum;
+ }
+ return (((this.cumulativeSum - this.minimumValue) > this.threshold));
+ }
+ return false;
+ }
+
+ public PageHinkleyTest getACopy() {
+ PageHinkleyTest newTest = new PageHinkleyTest(this.threshold, this.alpha);
+ this.copyFields(newTest);
+ return newTest;
+ }
+
+ protected void copyFields(PageHinkleyTest newTest) {
+ newTest.cumulativeSum = this.cumulativeSum;
+ newTest.minimumValue = this.minimumValue;
+ newTest.sumAbsolutError = this.sumAbsolutError;
+ newTest.phinstancesSeen = this.phinstancesSeen;
+ }
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java
index 47e1379..b8f22c3 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.cluster;
/*
@@ -26,148 +25,153 @@
/* micro cluster, as defined by Aggarwal et al, On Clustering Massive Data Streams: A Summarization Praradigm
* in the book Data streams : models and algorithms, by Charu C Aggarwal
* @article{
- title = {Data Streams: Models and Algorithms},
- author = {Aggarwal, Charu C.},
- year = {2007},
- publisher = {Springer Science+Business Media, LLC},
- url = {http://ebooks.ulb.tu-darmstadt.de/11157/},
- institution = {eBooks [http://ebooks.ulb.tu-darmstadt.de/perl/oai2] (Germany)},
-}
+ title = {Data Streams: Models and Algorithms},
+ author = {Aggarwal, Charu C.},
+ year = {2007},
+ publisher = {Springer Science+Business Media, LLC},
+ url = {http://ebooks.ulb.tu-darmstadt.de/11157/},
+ institution = {eBooks [http://ebooks.ulb.tu-darmstadt.de/perl/oai2] (Germany)},
+ }
-DEFINITION A micro-clusterfor a set of d-dimensionalpoints Xi,. .Xi,
-with t i m e s t a m p s ~. . .T,, is the (2-d+3)tuple (CF2", CFlX CF2t, CFlt, n),
-wherein CF2" and CFlX each correspond to a vector of d entries. The definition of each of these entries is as follows:
+ DEFINITION A micro-clusterfor a set of d-dimensionalpoints Xi,. .Xi,
+ with t i m e s t a m p s ~. . .T,, is the (2-d+3)tuple (CF2", CFlX CF2t, CFlt, n),
+ wherein CF2" and CFlX each correspond to a vector of d entries. The definition of each of these entries is as follows:
-o For each dimension, the sum of the squares of the data values is maintained
-in CF2". Thus, CF2" contains d values. The p-th entry of CF2" is equal to
-\sum_j=1^n(x_i_j)^2
+ o For each dimension, the sum of the squares of the data values is maintained
+ in CF2". Thus, CF2" contains d values. The p-th entry of CF2" is equal to
+ \sum_j=1^n(x_i_j)^2
-o For each dimension, the sum of the data values is maintained in C F l X .
-Thus, CFIX contains d values. The p-th entry of CFIX is equal to
-\sum_j=1^n x_i_j
+ o For each dimension, the sum of the data values is maintained in C F l X .
+ Thus, CFIX contains d values. The p-th entry of CFIX is equal to
+ \sum_j=1^n x_i_j
-o The sum of the squares of the time stamps Ti,. .Tin maintained in CF2t
+ o The sum of the squares of the time stamps Ti,. .Tin maintained in CF2t
-o The sum of the time stamps Ti, . . .Tin maintained in CFlt.
+ o The sum of the time stamps Ti, . . .Tin maintained in CFlt.
-o The number of data points is maintained in n.
+ o The number of data points is maintained in n.
*/
public abstract class CFCluster extends SphereCluster {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected double radiusFactor = 1.8;
+ protected double radiusFactor = 1.8;
- /**
- * Number of points in the cluster.
- */
- protected double N;
- /**
- * Linear sum of all the points added to the cluster.
- */
- public double[] LS;
- /**
- * Squared sum of all the points added to the cluster.
- */
- public double[] SS;
+ /**
+ * Number of points in the cluster.
+ */
+ protected double N;
+ /**
+ * Linear sum of all the points added to the cluster.
+ */
+ public double[] LS;
+ /**
+ * Squared sum of all the points added to the cluster.
+ */
+ public double[] SS;
- /**
- * Instantiates an empty kernel with the given dimensionality.
- * @param dimensions The number of dimensions of the points that can be in
- * this kernel.
- */
- public CFCluster(Instance instance, int dimensions) {
- this(instance.toDoubleArray(), dimensions);
- }
+ /**
+ * Instantiates an empty kernel with the given dimensionality.
+ *
+ * @param dimensions
+ * The number of dimensions of the points that can be in this kernel.
+ */
+ public CFCluster(Instance instance, int dimensions) {
+ this(instance.toDoubleArray(), dimensions);
+ }
- protected CFCluster(int dimensions) {
- this.N = 0;
- this.LS = new double[dimensions];
- this.SS = new double[dimensions];
- Arrays.fill(this.LS, 0.0);
- Arrays.fill(this.SS, 0.0);
- }
+ protected CFCluster(int dimensions) {
+ this.N = 0;
+ this.LS = new double[dimensions];
+ this.SS = new double[dimensions];
+ Arrays.fill(this.LS, 0.0);
+ Arrays.fill(this.SS, 0.0);
+ }
- public CFCluster(double [] center, int dimensions) {
- this.N = 1;
- this.LS = center;
- this.SS = new double[dimensions];
- for (int i = 0; i < SS.length; i++) {
- SS[i]=Math.pow(center[i], 2);
- }
- }
+ public CFCluster(double[] center, int dimensions) {
+ this.N = 1;
+ this.LS = center;
+ this.SS = new double[dimensions];
+ for (int i = 0; i < SS.length; i++) {
+ SS[i] = Math.pow(center[i], 2);
+ }
+ }
- public CFCluster(CFCluster cluster) {
- this.N = cluster.N;
- this.LS = Arrays.copyOf(cluster.LS, cluster.LS.length);
- this.SS = Arrays.copyOf(cluster.SS, cluster.SS.length);
- }
+ public CFCluster(CFCluster cluster) {
+ this.N = cluster.N;
+ this.LS = Arrays.copyOf(cluster.LS, cluster.LS.length);
+ this.SS = Arrays.copyOf(cluster.SS, cluster.SS.length);
+ }
- public void add(CFCluster cluster ) {
- this.N += cluster.N;
- addVectors( this.LS, cluster.LS );
- addVectors( this.SS, cluster.SS );
- }
+ public void add(CFCluster cluster) {
+ this.N += cluster.N;
+ addVectors(this.LS, cluster.LS);
+ addVectors(this.SS, cluster.SS);
+ }
- public abstract CFCluster getCF();
+ public abstract CFCluster getCF();
- /**
- * @return this kernels' center
- */
- @Override
- public double[] getCenter() {
- assert (this.N>0);
- double res[] = new double[this.LS.length];
- for ( int i = 0; i < res.length; i++ ) {
- res[i] = this.LS[i] / N;
- }
- return res;
- }
+ /**
+ * @return this kernels' center
+ */
+ @Override
+ public double[] getCenter() {
+ assert (this.N > 0);
+ double res[] = new double[this.LS.length];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = this.LS[i] / N;
+ }
+ return res;
+ }
+ @Override
+ public abstract double getInclusionProbability(Instance instance);
- @Override
- public abstract double getInclusionProbability(Instance instance);
+ /**
+ * See interface <code>Cluster</code>
+ *
+ * @return The radius of the cluster.
+ */
+ @Override
+ public abstract double getRadius();
- /**
- * See interface <code>Cluster</code>
- * @return The radius of the cluster.
- */
- @Override
- public abstract double getRadius();
+ /**
+ * See interface <code>Cluster</code>
+ *
+ * @return The weight.
+ * @see Cluster#getWeight()
+ */
+ @Override
+ public double getWeight() {
+ return N;
+ }
- /**
- * See interface <code>Cluster</code>
- * @return The weight.
- * @see Cluster#getWeight()
- */
- @Override
- public double getWeight() {
- return N;
- }
+ public void setN(double N) {
+ this.N = N;
+ }
- public void setN(double N){
- this.N = N;
- }
+ public double getN() {
+ return N;
+ }
- public double getN() {
- return N;
- }
+ /**
+ * Adds the second array to the first array element by element. The arrays
+ * must have the same length.
+ *
+ * @param a1
+ * Vector to which the second vector is added.
+ * @param a2
+ * Vector to be added. This vector does not change.
+ */
+ public static void addVectors(double[] a1, double[] a2) {
+ assert (a1 != null);
+ assert (a2 != null);
+ assert (a1.length == a2.length) : "Adding two arrays of different "
+ + "length";
- /**
- * Adds the second array to the first array element by element. The arrays
- * must have the same length.
- * @param a1 Vector to which the second vector is added.
- * @param a2 Vector to be added. This vector does not change.
- */
- public static void addVectors(double[] a1, double[] a2) {
- assert (a1 != null);
- assert (a2 != null);
- assert (a1.length == a2.length) : "Adding two arrays of different "
- + "length";
-
- for (int i = 0; i < a1.length; i++) {
- a1[i] += a2[i];
- }
- }
+ for (int i = 0; i < a1.length; i++) {
+ a1[i] += a2[i];
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java
index a9380a8..82cdb3b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.cluster;
/*
@@ -31,142 +30,139 @@
public abstract class Cluster extends AbstractMOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- private double id = -1;
- private double gtLabel = -1;
+ private double id = -1;
+ private double gtLabel = -1;
- private Map<String, String> measure_values;
+ private Map<String, String> measure_values;
+ public Cluster() {
+ this.measure_values = new HashMap<>();
+ }
- public Cluster(){
- this.measure_values = new HashMap<>();
+ /**
+ * @return the center of the cluster
+ */
+ public abstract double[] getCenter();
+
+ /**
+ * Returns the weight of this cluster, not necessarily normalized. It could,
+ * for instance, simply return the number of points contained in this cluster.
+ *
+ * @return the weight
+ */
+ public abstract double getWeight();
+
+ /**
+ * Returns the probability of the given point belonging to this cluster.
+ *
+ * @param instance
+ * @return a value between 0 and 1
+ */
+ public abstract double getInclusionProbability(Instance instance);
+
+ // TODO: for non sphere cluster sample points, find out MIN MAX neighbours
+ // within cluster
+ // and return the relative distance
+ // public abstract double getRelativeHullDistance(Instance instance);
+
+ @Override
+ public void getDescription(StringBuilder sb, int i) {
+ sb.append("Cluster Object");
+ }
+
+ public void setId(double id) {
+ this.id = id;
+ }
+
+ public double getId() {
+ return id;
+ }
+
+ public boolean isGroundTruth() {
+ return gtLabel != -1;
+ }
+
+ public void setGroundTruth(double truth) {
+ gtLabel = truth;
+ }
+
+ public double getGroundTruth() {
+ return gtLabel;
+ }
+
+ /**
+ * Samples this cluster by returning a point from inside it.
+ *
+ * @param random
+ * a random number source
+ * @return an Instance that lies inside this cluster
+ */
+ public abstract Instance sample(Random random);
+
+ public void setMeasureValue(String measureKey, String value) {
+ measure_values.put(measureKey, value);
+ }
+
+ public void setMeasureValue(String measureKey, double value) {
+ measure_values.put(measureKey, Double.toString(value));
+ }
+
+ public String getMeasureValue(String measureKey) {
+ if (measure_values.containsKey(measureKey))
+ return measure_values.get(measureKey);
+ else
+ return "";
+ }
+
+ protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue) {
+ infoTitle.add("ClusterID");
+ infoValue.add(Integer.toString((int) getId()));
+
+ infoTitle.add("Type");
+ infoValue.add(getClass().getSimpleName());
+
+ double c[] = getCenter();
+ if (c != null)
+ for (int i = 0; i < c.length; i++) {
+ infoTitle.add("Dim" + i);
+ infoValue.add(Double.toString(c[i]));
+ }
+
+ infoTitle.add("Weight");
+ infoValue.add(Double.toString(getWeight()));
+
+ }
+
+ public String getInfo() {
+ List<String> infoTitle = new ArrayList<>();
+ List<String> infoValue = new ArrayList<>();
+ getClusterSpecificInfo(infoTitle, infoValue);
+
+ StringBuilder sb = new StringBuilder();
+
+ // Cluster properties
+ sb.append("<html>");
+ sb.append("<table>");
+ int i = 0;
+ while (i < infoTitle.size() && i < infoValue.size()) {
+ sb.append("<tr><td>" + infoTitle.get(i) + "</td><td>" + infoValue.get(i) + "</td></tr>");
+ i++;
}
- /**
- * @return the center of the cluster
- */
- public abstract double[] getCenter();
+ sb.append("</table>");
- /**
- * Returns the weight of this cluster, not neccessarily normalized.
- * It could, for instance, simply return the number of points contined
- * in this cluster.
- * @return the weight
- */
- public abstract double getWeight();
-
- /**
- * Returns the probability of the given point belonging to
- * this cluster.
- *
- * @param instance
- * @return a value between 0 and 1
- */
- public abstract double getInclusionProbability(Instance instance);
-
-
- //TODO: for non sphere cluster sample points, find out MIN MAX neighbours within cluster
- //and return the relative distance
- //public abstract double getRelativeHullDistance(Instance instance);
-
- @Override
- public void getDescription(StringBuilder sb, int i) {
- sb.append("Cluster Object");
+ // Evaluation info
+ sb.append("<br>");
+ sb.append("<b>Evaluation</b><br>");
+ sb.append("<table>");
+ for (Object o : measure_values.entrySet()) {
+ Map.Entry e = (Map.Entry) o;
+ sb.append("<tr><td>" + e.getKey() + "</td><td>" + e.getValue() + "</td></tr>");
}
-
- public void setId(double id) {
- this.id = id;
- }
-
- public double getId() {
- return id;
- }
-
- public boolean isGroundTruth(){
- return gtLabel != -1;
- }
-
- public void setGroundTruth(double truth){
- gtLabel = truth;
- }
-
- public double getGroundTruth(){
- return gtLabel;
- }
-
-
- /**
- * Samples this cluster by returning a point from inside it.
- * @param random a random number source
- * @return an Instance that lies inside this cluster
- */
- public abstract Instance sample(Random random);
-
-
- public void setMeasureValue(String measureKey, String value){
- measure_values.put(measureKey, value);
- }
-
- public void setMeasureValue(String measureKey, double value){
- measure_values.put(measureKey, Double.toString(value));
- }
-
-
- public String getMeasureValue(String measureKey){
- if(measure_values.containsKey(measureKey))
- return measure_values.get(measureKey);
- else
- return "";
- }
-
-
- protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue){
- infoTitle.add("ClusterID");
- infoValue.add(Integer.toString((int)getId()));
-
- infoTitle.add("Type");
- infoValue.add(getClass().getSimpleName());
-
- double c[] = getCenter();
- if(c!=null)
- for (int i = 0; i < c.length; i++) {
- infoTitle.add("Dim"+i);
- infoValue.add(Double.toString(c[i]));
- }
-
- infoTitle.add("Weight");
- infoValue.add(Double.toString(getWeight()));
-
- }
-
- public String getInfo() {
- List<String> infoTitle = new ArrayList<>();
- List<String> infoValue = new ArrayList<>();
- getClusterSpecificInfo(infoTitle, infoValue);
-
- StringBuilder sb = new StringBuilder();
-
- //Cluster properties
- sb.append("<html>");
- sb.append("<table>");
- int i = 0;
- while(i < infoTitle.size() && i < infoValue.size()){
- sb.append("<tr><td>"+infoTitle.get(i)+"</td><td>"+infoValue.get(i)+"</td></tr>");
- i++;
- }
- sb.append("</table>");
-
- //Evaluation info
- sb.append("<br>");
- sb.append("<b>Evaluation</b><br>");
- sb.append("<table>");
- for (Object o : measure_values.entrySet()) {
- Map.Entry e = (Map.Entry) o;
- sb.append("<tr><td>" + e.getKey() + "</td><td>" + e.getValue() + "</td></tr>");
- }
- sb.append("</table>");
- sb.append("</html>");
- return sb.toString();
- }
+ sb.append("</table>");
+ sb.append("</html>");
+ return sb.toString();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Clustering.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Clustering.java
index 5f76ba5..b7c9862 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Clustering.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Clustering.java
@@ -29,248 +29,242 @@
import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.Instance;
-public class Clustering extends AbstractMOAObject{
+public class Clustering extends AbstractMOAObject {
- private AutoExpandVector<Cluster> clusters;
+ private AutoExpandVector<Cluster> clusters;
- public Clustering() {
- this.clusters = new AutoExpandVector<Cluster>();
+ public Clustering() {
+ this.clusters = new AutoExpandVector<Cluster>();
+ }
+
+ public Clustering(Cluster[] clusters) {
+ this.clusters = new AutoExpandVector<Cluster>();
+ for (int i = 0; i < clusters.length; i++) {
+ this.clusters.add(clusters[i]);
+ }
+ }
+
+ public Clustering(List<? extends Instance> points) {
+ HashMap<Integer, Integer> labelMap = classValues(points);
+ int dim = points.get(0).dataset().numAttributes() - 1;
+
+ int numClasses = labelMap.size();
+ int noiseLabel;
+
+ Attribute classLabel = points.get(0).dataset().classAttribute();
+ int lastLabelIndex = classLabel.numValues() - 1;
+ if (classLabel.value(lastLabelIndex) == "noise") {
+ noiseLabel = lastLabelIndex;
+ } else {
+ noiseLabel = -1;
}
- public Clustering( Cluster[] clusters ) {
- this.clusters = new AutoExpandVector<Cluster>();
- for (int i = 0; i < clusters.length; i++) {
- this.clusters.add(clusters[i]);
- }
+ ArrayList<Instance>[] sorted_points = (ArrayList<Instance>[]) new ArrayList[numClasses];
+ for (int i = 0; i < numClasses; i++) {
+ sorted_points[i] = new ArrayList<Instance>();
+ }
+ for (Instance point : points) {
+ int clusterid = (int) point.classValue();
+ if (clusterid == noiseLabel)
+ continue;
+ sorted_points[labelMap.get(clusterid)].add((Instance) point);
+ }
+ this.clusters = new AutoExpandVector<Cluster>();
+ for (int i = 0; i < numClasses; i++) {
+ if (sorted_points[i].size() > 0) {
+ SphereCluster s = new SphereCluster(sorted_points[i], dim);
+ s.setId(sorted_points[i].get(0).classValue());
+ s.setGroundTruth(sorted_points[i].get(0).classValue());
+ clusters.add(s);
+ }
+ }
+ }
+
+ public Clustering(ArrayList<DataPoint> points, double overlapThreshold, int initMinPoints) {
+ HashMap<Integer, Integer> labelMap = Clustering.classValues(points);
+ int dim = points.get(0).dataset().numAttributes() - 1;
+
+ int numClasses = labelMap.size();
+ int num = 0;
+
+ ArrayList<DataPoint>[] sorted_points = (ArrayList<DataPoint>[]) new ArrayList[numClasses];
+ for (int i = 0; i < numClasses; i++) {
+ sorted_points[i] = new ArrayList<DataPoint>();
+ }
+ for (DataPoint point : points) {
+ int clusterid = (int) point.classValue();
+ if (clusterid == -1)
+ continue;
+ sorted_points[labelMap.get(clusterid)].add(point);
+ num++;
}
- public Clustering(List<? extends Instance> points){
- HashMap<Integer, Integer> labelMap = classValues(points);
- int dim = points.get(0).dataset().numAttributes()-1;
+ clusters = new AutoExpandVector<Cluster>();
+ int microID = 0;
+ for (int i = 0; i < numClasses; i++) {
+ ArrayList<SphereCluster> microByClass = new ArrayList<SphereCluster>();
+ ArrayList<DataPoint> pointInCluster = new ArrayList<DataPoint>();
+ ArrayList<ArrayList<Instance>> pointInMicroClusters = new ArrayList();
- int numClasses = labelMap.size();
- int noiseLabel;
-
- Attribute classLabel = points.get(0).dataset().classAttribute();
- int lastLabelIndex = classLabel.numValues() - 1;
- if (classLabel.value(lastLabelIndex) == "noise") {
- noiseLabel = lastLabelIndex;
- } else {
- noiseLabel = -1;
+ pointInCluster.addAll(sorted_points[i]);
+ while (pointInCluster.size() > 0) {
+ ArrayList<Instance> micro_points = new ArrayList<Instance>();
+ for (int j = 0; j < initMinPoints && !pointInCluster.isEmpty(); j++) {
+ micro_points.add((Instance) pointInCluster.get(0));
+ pointInCluster.remove(0);
}
-
- ArrayList<Instance>[] sorted_points = (ArrayList<Instance>[]) new ArrayList[numClasses];
- for (int i = 0; i < numClasses; i++) {
- sorted_points[i] = new ArrayList<Instance>();
- }
- for (Instance point : points) {
- int clusterid = (int)point.classValue();
- if(clusterid == noiseLabel) continue;
- sorted_points[labelMap.get(clusterid)].add((Instance)point);
- }
- this.clusters = new AutoExpandVector<Cluster>();
- for (int i = 0; i < numClasses; i++) {
- if(sorted_points[i].size()>0){
- SphereCluster s = new SphereCluster(sorted_points[i],dim);
- s.setId(sorted_points[i].get(0).classValue());
- s.setGroundTruth(sorted_points[i].get(0).classValue());
- clusters.add(s);
+ if (micro_points.size() > 0) {
+ SphereCluster s = new SphereCluster(micro_points, dim);
+ for (int c = 0; c < microByClass.size(); c++) {
+ if (((SphereCluster) microByClass.get(c)).overlapRadiusDegree(s) > overlapThreshold) {
+ micro_points.addAll(pointInMicroClusters.get(c));
+ s = new SphereCluster(micro_points, dim);
+ pointInMicroClusters.remove(c);
+ microByClass.remove(c);
+ // System.out.println("Removing redundant cluster based on radius overlap"+c);
}
- }
- }
-
- public Clustering(ArrayList<DataPoint> points, double overlapThreshold, int initMinPoints){
- HashMap<Integer, Integer> labelMap = Clustering.classValues(points);
- int dim = points.get(0).dataset().numAttributes()-1;
-
- int numClasses = labelMap.size();
- int num = 0;
-
- ArrayList<DataPoint>[] sorted_points = (ArrayList<DataPoint>[]) new ArrayList[numClasses];
- for (int i = 0; i < numClasses; i++) {
- sorted_points[i] = new ArrayList<DataPoint>();
- }
- for (DataPoint point : points) {
- int clusterid = (int)point.classValue();
- if(clusterid == -1) continue;
- sorted_points[labelMap.get(clusterid)].add(point);
- num++;
- }
-
- clusters = new AutoExpandVector<Cluster>();
- int microID = 0;
- for (int i = 0; i < numClasses; i++) {
- ArrayList<SphereCluster> microByClass = new ArrayList<SphereCluster>();
- ArrayList<DataPoint> pointInCluster = new ArrayList<DataPoint>();
- ArrayList<ArrayList<Instance>> pointInMicroClusters = new ArrayList();
-
- pointInCluster.addAll(sorted_points[i]);
- while(pointInCluster.size()>0){
- ArrayList<Instance> micro_points = new ArrayList<Instance>();
- for (int j = 0; j < initMinPoints && !pointInCluster.isEmpty(); j++) {
- micro_points.add((Instance)pointInCluster.get(0));
- pointInCluster.remove(0);
- }
- if(micro_points.size() > 0){
- SphereCluster s = new SphereCluster(micro_points,dim);
- for (int c = 0; c < microByClass.size(); c++) {
- if(((SphereCluster)microByClass.get(c)).overlapRadiusDegree(s) > overlapThreshold ){
- micro_points.addAll(pointInMicroClusters.get(c));
- s = new SphereCluster(micro_points,dim);
- pointInMicroClusters.remove(c);
- microByClass.remove(c);
- //System.out.println("Removing redundant cluster based on radius overlap"+c);
- }
- }
- for (int j = 0; j < pointInCluster.size(); j++) {
- Instance instance = pointInCluster.get(j);
- if(s.getInclusionProbability(instance) > 0.8){
- pointInCluster.remove(j);
- micro_points.add(instance);
- }
- }
- s.setWeight(micro_points.size());
- microByClass.add(s);
- pointInMicroClusters.add(micro_points);
- microID++;
- }
+ }
+ for (int j = 0; j < pointInCluster.size(); j++) {
+ Instance instance = pointInCluster.get(j);
+ if (s.getInclusionProbability(instance) > 0.8) {
+ pointInCluster.remove(j);
+ micro_points.add(instance);
}
- //
- boolean changed = true;
- while(changed){
- changed = false;
- for(int c = 0; c < microByClass.size(); c++) {
- for(int c1 = c+1; c1 < microByClass.size(); c1++) {
- double overlap = microByClass.get(c).overlapRadiusDegree(microByClass.get(c1));
-// System.out.println("Overlap C"+(clustering.size()+c)+" ->C"+(clustering.size()+c1)+": "+overlap);
- if(overlap > overlapThreshold){
- pointInMicroClusters.get(c).addAll(pointInMicroClusters.get(c1));
- SphereCluster s = new SphereCluster(pointInMicroClusters.get(c),dim);
- microByClass.set(c, s);
- pointInMicroClusters.remove(c1);
- microByClass.remove(c1);
- changed = true;
- break;
- }
- }
- }
+ }
+ s.setWeight(micro_points.size());
+ microByClass.add(s);
+ pointInMicroClusters.add(micro_points);
+ microID++;
+ }
+ }
+ //
+ boolean changed = true;
+ while (changed) {
+ changed = false;
+ for (int c = 0; c < microByClass.size(); c++) {
+ for (int c1 = c + 1; c1 < microByClass.size(); c1++) {
+ double overlap = microByClass.get(c).overlapRadiusDegree(microByClass.get(c1));
+ // System.out.println("Overlap C"+(clustering.size()+c)+" ->C"+(clustering.size()+c1)+": "+overlap);
+ if (overlap > overlapThreshold) {
+ pointInMicroClusters.get(c).addAll(pointInMicroClusters.get(c1));
+ SphereCluster s = new SphereCluster(pointInMicroClusters.get(c), dim);
+ microByClass.set(c, s);
+ pointInMicroClusters.remove(c1);
+ microByClass.remove(c1);
+ changed = true;
+ break;
}
- for (int j = 0; j < microByClass.size(); j++) {
- microByClass.get(j).setGroundTruth(sorted_points[i].get(0).classValue());
- clusters.add(microByClass.get(j));
- }
-
+ }
}
- for (int j = 0; j < clusters.size(); j++) {
- clusters.get(j).setId(j);
+ }
+ for (int j = 0; j < microByClass.size(); j++) {
+ microByClass.get(j).setGroundTruth(sorted_points[i].get(0).classValue());
+ clusters.add(microByClass.get(j));
+ }
+
+ }
+ for (int j = 0; j < clusters.size(); j++) {
+ clusters.get(j).setId(j);
+ }
+
+ }
+
+ /**
+ * @param points
+ * @return a map from each class label found in the points to a consecutive cluster index
+ */
+ public static HashMap<Integer, Integer> classValues(List<? extends Instance> points) {
+ HashMap<Integer, Integer> classes = new HashMap<Integer, Integer>();
+ int workcluster = 0;
+ boolean hasnoise = false;
+ for (int i = 0; i < points.size(); i++) {
+ int label = (int) points.get(i).classValue();
+ if (label == -1) {
+ hasnoise = true;
+ }
+ else {
+ if (!classes.containsKey(label)) {
+ classes.put(label, workcluster);
+ workcluster++;
}
-
+ }
}
+ if (hasnoise)
+ classes.put(-1, workcluster);
+ return classes;
+ }
- /**
- * @param points
- * @return an array with the min and max class label value
- */
- public static HashMap<Integer, Integer> classValues(List<? extends Instance> points){
- HashMap<Integer,Integer> classes = new HashMap<Integer, Integer>();
- int workcluster = 0;
- boolean hasnoise = false;
- for (int i = 0; i < points.size(); i++) {
- int label = (int) points.get(i).classValue();
- if(label == -1){
- hasnoise = true;
- }
- else{
- if(!classes.containsKey(label)){
- classes.put(label,workcluster);
- workcluster++;
- }
- }
- }
- if(hasnoise)
- classes.put(-1,workcluster);
- return classes;
+ public Clustering(AutoExpandVector<Cluster> clusters) {
+ this.clusters = clusters;
+ }
+
+ /**
+ * add a cluster to the clustering
+ */
+ public void add(Cluster cluster) {
+ clusters.add(cluster);
+ }
+
+ /**
+ * remove a cluster from the clustering
+ */
+ public void remove(int index) {
+ if (index < clusters.size()) {
+ clusters.remove(index);
}
+ }
- public Clustering(AutoExpandVector<Cluster> clusters) {
- this.clusters = clusters;
+ /**
+ * remove a cluster from the clustering
+ */
+ public Cluster get(int index) {
+ if (index < clusters.size()) {
+ return clusters.get(index);
}
+ return null;
+ }
+ /**
+ * @return the <code>Clustering</code> as an AutoExpandVector
+ */
+ public AutoExpandVector<Cluster> getClustering() {
+ return clusters;
+ }
- /**
- * add a cluster to the clustering
- */
- public void add(Cluster cluster){
- clusters.add(cluster);
+ /**
+ * @return A deepcopy of the <code>Clustering</code> as an AutoExpandVector
+ */
+ public AutoExpandVector<Cluster> getClusteringCopy() {
+ return (AutoExpandVector<Cluster>) clusters.copy();
+ }
+
+ /**
+ * @return the number of clusters
+ */
+ public int size() {
+ return clusters.size();
+ }
+
+ /**
+ * @return the number of dimensions of this clustering
+ */
+ public int dimension() {
+ assert (clusters.size() != 0);
+ return clusters.get(0).getCenter().length;
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int i) {
+ sb.append("Clustering Object");
+ }
+
+ public double getMaxInclusionProbability(Instance point) {
+ double maxInclusion = 0.0;
+ for (int i = 0; i < clusters.size(); i++) {
+ maxInclusion = Math.max(clusters.get(i).getInclusionProbability(point),
+ maxInclusion);
}
-
- /**
- * remove a cluster from the clustering
- */
- public void remove(int index){
- if(index < clusters.size()){
- clusters.remove(index);
- }
- }
-
- /**
- * remove a cluster from the clustering
- */
- public Cluster get(int index){
- if(index < clusters.size()){
- return clusters.get(index);
- }
- return null;
- }
-
- /**
- * @return the <code>Clustering</code> as an AutoExpandVector
- */
- public AutoExpandVector<Cluster> getClustering() {
- return clusters;
- }
-
- /**
- * @return A deepcopy of the <code>Clustering</code> as an AutoExpandVector
- */
- public AutoExpandVector<Cluster> getClusteringCopy() {
- return (AutoExpandVector<Cluster>)clusters.copy();
- }
-
-
- /**
- * @return the number of clusters
- */
- public int size() {
- return clusters.size();
- }
-
- /**
- * @return the number of dimensions of this clustering
- */
- public int dimension() {
- assert (clusters.size() != 0);
- return clusters.get(0).getCenter().length;
- }
-
- @Override
- public void getDescription(StringBuilder sb, int i) {
- sb.append("Clustering Object");
- }
-
-
-
- public double getMaxInclusionProbability(Instance point) {
- double maxInclusion = 0.0;
- for (int i = 0; i < clusters.size(); i++) {
- maxInclusion = Math.max(clusters.get(i).getInclusionProbability(point),
- maxInclusion);
- }
- return maxInclusion;
- }
-
-
-
-
+ return maxInclusion;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Miniball.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Miniball.java
index 45c947e..f27a6ad 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Miniball.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Miniball.java
@@ -27,58 +27,58 @@
public class Miniball {
- private int dimension;
- private com.dreizak.miniball.highdim.Miniball mb;
- private PointStorage pointSet;
+ private int dimension;
+ private com.dreizak.miniball.highdim.Miniball mb;
+ private PointStorage pointSet;
- public Miniball(int dimension) {
- this.dimension = dimension;
+ public Miniball(int dimension) {
+ this.dimension = dimension;
+ }
+
+ void clear() {
+ this.pointSet = new PointStorage(this.dimension);
+ }
+
+ void check_in(double[] array) {
+ this.pointSet.add(array);
+ }
+
+ double[] center() {
+ return this.mb.center();
+ }
+
+ double radius() {
+ return this.mb.radius();
+ }
+
+ void build() {
+ this.mb = new com.dreizak.miniball.highdim.Miniball(this.pointSet);
+ }
+
+ public class PointStorage implements PointSet {
+
+ protected int dimension;
+ protected List<double[]> L;
+
+ public PointStorage(int dimension) {
+ this.dimension = dimension;
+ this.L = new ArrayList<double[]>();
}
- void clear() {
- this.pointSet = new PointStorage(this.dimension);
+ public void add(double[] array) {
+ this.L.add(array);
}
- void check_in(double[] array) {
- this.pointSet.add(array);
+ public int size() {
+ return L.size();
}
- double[] center() {
- return this.mb.center();
+ public int dimension() {
+ return dimension;
}
- double radius() {
- return this.mb.radius();
+ public double coord(int point, int coordinate) {
+ return L.get(point)[coordinate];
}
-
- void build() {
- this.mb = new com.dreizak.miniball.highdim.Miniball(this.pointSet);
- }
-
- public class PointStorage implements PointSet {
-
- protected int dimension;
- protected List<double[]> L;
-
- public PointStorage(int dimension) {
- this.dimension = dimension;
- this.L = new ArrayList<double[]>();
- }
-
- public void add(double[] array) {
- this.L.add(array);
- }
-
- public int size() {
- return L.size();
- }
-
- public int dimension() {
- return dimension;
- }
-
- public double coord(int point, int coordinate) {
- return L.get(point)[coordinate];
- }
- }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/SphereCluster.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/SphereCluster.java
index 558b0f1..4a5cc97 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/SphereCluster.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/SphereCluster.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.cluster;
/*
@@ -28,346 +27,341 @@
/**
* A simple implementation of the <code>Cluster</code> interface representing
- * spherical clusters. The inclusion probability is one inside the sphere and zero
- * everywhere else.
- *
+ * spherical clusters. The inclusion probability is one inside the sphere and
+ * zero everywhere else.
+ *
*/
public class SphereCluster extends Cluster {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- private double[] center;
- private double radius;
- private double weight;
+ private double[] center;
+ private double radius;
+ private double weight;
+ public SphereCluster(double[] center, double radius) {
+ this(center, radius, 1.0);
+ }
- public SphereCluster(double[] center, double radius) {
- this( center, radius, 1.0 );
- }
+ public SphereCluster() {
+ }
- public SphereCluster() {
- }
+ public SphereCluster(double[] center, double radius, double weightedSize) {
+ this();
+ this.center = center;
+ this.radius = radius;
+ this.weight = weightedSize;
+ }
- public SphereCluster( double[] center, double radius, double weightedSize) {
- this();
- this.center = center;
- this.radius = radius;
- this.weight = weightedSize;
- }
+ public SphereCluster(int dimensions, double radius, Random random) {
+ this();
+ this.center = new double[dimensions];
+ this.radius = radius;
- public SphereCluster(int dimensions, double radius, Random random) {
- this();
- this.center = new double[dimensions];
- this.radius = radius;
+ // Position randomly but keep hypersphere inside the boundaries
+ double interval = 1.0 - 2 * radius;
+ for (int i = 0; i < center.length; i++) {
+ this.center[i] = (random.nextDouble() * interval) + radius;
+ }
+ this.weight = 0.0;
+ }
- // Position randomly but keep hypersphere inside the boundaries
- double interval = 1.0 - 2 * radius;
- for (int i = 0; i < center.length; i++) {
- this.center[i] = (random.nextDouble() * interval) + radius;
- }
- this.weight = 0.0;
- }
+ public SphereCluster(List<? extends Instance> instances, int dimension) {
+ this();
+ if (instances == null || instances.size() <= 0)
+ return;
+ weight = instances.size();
- public SphereCluster(List<?extends Instance> instances, int dimension){
- this();
- if(instances == null || instances.size() <= 0)
- return;
+ Miniball mb = new Miniball(dimension);
+ mb.clear();
- weight = instances.size();
+ for (Instance instance : instances) {
+ mb.check_in(instance.toDoubleArray());
+ }
- Miniball mb = new Miniball(dimension);
- mb.clear();
+ mb.build();
+ center = mb.center();
+ radius = mb.radius();
+ mb.clear();
+ }
- for (Instance instance : instances) {
- mb.check_in(instance.toDoubleArray());
- }
+ /**
+ * Checks whether two <code>SphereCluster</code> overlap based on radius NOTE:
+ * overlapRadiusDegree only calculates the overlap based on the centers and
+ * the radii, so not the real overlap
+ *
+ * TODO: should we do this by MC to get the real overlap???
+ *
+ * @param other
+ * @return
+ */
- mb.build();
- center = mb.center();
- radius = mb.radius();
- mb.clear();
- }
+ public double overlapRadiusDegree(SphereCluster other) {
+ double[] center0 = getCenter();
+ double radius0 = getRadius();
- /**
- * Checks whether two <code>SphereCluster</code> overlap based on radius
- * NOTE: overlapRadiusDegree only calculates the overlap based
- * on the centers and the radi, so not the real overlap
- *
- * TODO: should we do this by MC to get the real overlap???
- *
- * @param other
- * @return
+ double[] center1 = other.getCenter();
+ double radius1 = other.getRadius();
+
+ double radiusBig;
+ double radiusSmall;
+ if (radius0 < radius1) {
+ radiusBig = radius1;
+ radiusSmall = radius0;
+ }
+ else {
+ radiusBig = radius0;
+ radiusSmall = radius1;
+ }
+
+ double dist = 0;
+ for (int i = 0; i < center0.length; i++) {
+ double delta = center0[i] - center1[i];
+ dist += delta * delta;
+ }
+ dist = Math.sqrt(dist);
+
+ if (dist > radiusSmall + radiusBig)
+ return 0;
+ if (dist + radiusSmall <= radiusBig) {
+ // one lies within the other
+ return 1;
+ }
+ else {
+ return (radiusSmall + radiusBig - dist) / (2 * radiusSmall);
+ }
+ }
+
+ public void combine(SphereCluster cluster) {
+ double[] center = getCenter();
+ double[] newcenter = new double[center.length];
+ double[] other_center = cluster.getCenter();
+ double other_weight = cluster.getWeight();
+ double other_radius = cluster.getRadius();
+
+ for (int i = 0; i < center.length; i++) {
+ newcenter[i] = (center[i] * getWeight() + other_center[i] * other_weight) / (getWeight() + other_weight);
+ }
+
+ center = newcenter;
+ double r_0 = getRadius() + Math.abs(distance(center, newcenter));
+ double r_1 = other_radius + Math.abs(distance(other_center, newcenter));
+ radius = Math.max(r_0, r_1);
+ weight += other_weight;
+ }
+
+ public void merge(SphereCluster cluster) {
+ double[] c0 = getCenter();
+ double w0 = getWeight();
+ double r0 = getRadius();
+
+ double[] c1 = cluster.getCenter();
+ double w1 = cluster.getWeight();
+ double r1 = cluster.getRadius();
+
+ // vector
+ double[] v = new double[c0.length];
+ // center distance
+ double d = 0;
+
+ for (int i = 0; i < c0.length; i++) {
+ v[i] = c0[i] - c1[i];
+ d += v[i] * v[i];
+ }
+ d = Math.sqrt(d);
+
+ double r;
+ double[] c = new double[c0.length];
+
+ // one lies within the other
+ if (d + r0 <= r1 || d + r1 <= r0) {
+ if (d + r0 <= r1) {
+ r = r1;
+ c = c1;
+ }
+ else {
+ r = r0;
+ c = c0;
+ }
+ }
+ else {
+ r = (r0 + r1 + d) / 2.0;
+ for (int i = 0; i < c.length; i++) {
+ c[i] = c1[i] - v[i] / d * (r1 - r);
+ }
+ }
+
+ setCenter(c);
+ setRadius(r);
+ setWeight(w0 + w1);
+
+ }
+
+ @Override
+ public double[] getCenter() {
+ double[] copy = new double[center.length];
+ System.arraycopy(center, 0, copy, 0, center.length);
+ return copy;
+ }
+
+ public void setCenter(double[] center) {
+ this.center = center;
+ }
+
+ public double getRadius() {
+ return radius;
+ }
+
+ public void setRadius(double radius) {
+ this.radius = radius;
+ }
+
+ @Override
+ public double getWeight() {
+ return weight;
+ }
+
+ public void setWeight(double weight) {
+ this.weight = weight;
+ }
+
+ @Override
+ public double getInclusionProbability(Instance instance) {
+ if (getCenterDistance(instance) <= getRadius()) {
+ return 1.0;
+ }
+ return 0.0;
+ }
+
+ public double getCenterDistance(Instance instance) {
+ double distance = 0.0;
+ // get the center through getCenter so subclasses have a chance
+ double[] center = getCenter();
+ for (int i = 0; i < center.length; i++) {
+ double d = center[i] - instance.value(i);
+ distance += d * d;
+ }
+ return Math.sqrt(distance);
+ }
+
+ public double getCenterDistance(SphereCluster other) {
+ return distance(getCenter(), other.getCenter());
+ }
+
+ /*
+ * the minimal distance between the surface of two clusters. It is negative
+ * if the two clusters overlap
+ */
+ public double getHullDistance(SphereCluster other) {
+ double distance;
+ // get the center through getCenter so subclasses have a chance
+ double[] center0 = getCenter();
+ double[] center1 = other.getCenter();
+ distance = distance(center0, center1);
+
+ distance = distance - getRadius() - other.getRadius();
+ return distance;
+ }
+
+ /*
*/
+ /**
+ * When a cluster loses points the new minimal bounding sphere can be partly
+ * outside of the originating cluster. If another cluster is right next to
+ * the original cluster (without overlapping), the new cluster can be
+ * overlapping with this second cluster. OverlapSave will tell you if the
+ * current cluster can degenerate so much that it overlaps with cluster
+ * 'other'
+ *
+ * @param other
+ * the potentially overlapping cluster
+ * @return true if cluster can potentially overlap
+ */
+ public boolean overlapSave(SphereCluster other) {
+ // use basic geometry to figure out the maximal degenerated cluster
+ // comes down to Max(radius *(sin alpha + cos alpha)) which is
+ double minDist = Math.sqrt(2) * (getRadius() + other.getRadius());
+ double diff = getCenterDistance(other) - minDist;
- public double overlapRadiusDegree(SphereCluster other) {
+ return diff > 0;
+ }
+ private double distance(double[] v1, double[] v2) {
+ double distance = 0.0;
+ double[] center = getCenter(); // get the center through getCenter so subclasses have a chance
+ for (int i = 0; i < center.length; i++) {
+ double d = v1[i] - v2[i];
+ distance += d * d;
+ }
+ return Math.sqrt(distance);
+ }
- double[] center0 = getCenter();
- double radius0 = getRadius();
+ public double[] getDistanceVector(Instance instance) {
+ return distanceVector(getCenter(), instance.toDoubleArray());
+ }
- double[] center1 = other.getCenter();
- double radius1 = other.getRadius();
+ public double[] getDistanceVector(SphereCluster other) {
+ return distanceVector(getCenter(), other.getCenter());
+ }
- double radiusBig;
- double radiusSmall;
- if(radius0 < radius1){
- radiusBig = radius1;
- radiusSmall = radius0;
- }
- else{
- radiusBig = radius0;
- radiusSmall = radius1;
- }
+ private double[] distanceVector(double[] v1, double[] v2) {
+ double[] v = new double[v1.length];
+ for (int i = 0; i < v1.length; i++) {
+ v[i] = v2[i] - v1[i];
+ }
+ return v;
+ }
- double dist = 0;
- for (int i = 0; i < center0.length; i++) {
- double delta = center0[i] - center1[i];
- dist += delta * delta;
- }
- dist = Math.sqrt(dist);
+ /**
+ * Samples this cluster by returning a point from inside it.
+ *
+ * @param random
+ * a random number source
+ * @return a point that lies inside this cluster
+ */
+ public Instance sample(Random random) {
+ // Create sample in hypersphere coordinates
+ // get the center through getCenter so subclasses have a chance
+ double[] center = getCenter();
- if(dist > radiusSmall + radiusBig)
- return 0;
- if(dist + radiusSmall <= radiusBig){
- //one lies within the other
- return 1;
- }
- else{
- return (radiusSmall+radiusBig-dist)/(2*radiusSmall);
- }
- }
+ final int dimensions = center.length;
- public void combine(SphereCluster cluster) {
- double[] center = getCenter();
- double[] newcenter = new double[center.length];
- double[] other_center = cluster.getCenter();
- double other_weight = cluster.getWeight();
- double other_radius = cluster.getRadius();
+ final double sin[] = new double[dimensions - 1];
+ final double cos[] = new double[dimensions - 1];
+ final double length = random.nextDouble() * getRadius();
- for (int i = 0; i < center.length; i++) {
- newcenter[i] = (center[i]*getWeight()+other_center[i]*other_weight)/(getWeight()+other_weight);
- }
+ double lastValue = 1.0;
+ for (int i = 0; i < dimensions - 1; i++) {
+ double angle = random.nextDouble() * 2 * Math.PI;
+ sin[i] = lastValue * Math.sin(angle); // Store cumulative values
+ cos[i] = Math.cos(angle);
+ lastValue = sin[i];
+ }
- center = newcenter;
- double r_0 = getRadius() + Math.abs(distance(center, newcenter));
- double r_1 = other_radius + Math.abs(distance(other_center, newcenter));
- radius = Math.max(r_0, r_1);
- weight+= other_weight;
- }
+ // Calculate cartesian coordinates
+ double res[] = new double[dimensions];
- public void merge(SphereCluster cluster) {
- double[] c0 = getCenter();
- double w0 = getWeight();
- double r0 = getRadius();
+ // First value uses only cosines
+ res[0] = center[0] + length * cos[0];
- double[] c1 = cluster.getCenter();
- double w1 = cluster.getWeight();
- double r1 = cluster.getRadius();
+ // Loop through 'middle' coordinates which use cosines and sines
+ for (int i = 1; i < dimensions - 1; i++) {
+ res[i] = center[i] + length * sin[i - 1] * cos[i];
+ }
- //vector
- double[] v = new double[c0.length];
- //center distance
- double d = 0;
+ // Last value uses only sines
+ res[dimensions - 1] = center[dimensions - 1] + length * sin[dimensions - 2];
- for (int i = 0; i < c0.length; i++) {
- v[i] = c0[i] - c1[i];
- d += v[i] * v[i];
- }
- d = Math.sqrt(d);
+ return new DenseInstance(1.0, res);
+ }
-
-
- double r;
- double[] c = new double[c0.length];
-
- //one lays within the others
- if(d + r0 <= r1 || d + r1 <= r0){
- if(d + r0 <= r1){
- r = r1;
- c = c1;
- }
- else{
- r = r0;
- c = c0;
- }
- }
- else{
- r = (r0 + r1 + d)/2.0;
- for (int i = 0; i < c.length; i++) {
- c[i] = c1[i] - v[i]/d * (r1-r);
- }
- }
-
- setCenter(c);
- setRadius(r);
- setWeight(w0+w1);
-
- }
-
- @Override
- public double[] getCenter() {
- double[] copy = new double[center.length];
- System.arraycopy(center, 0, copy, 0, center.length);
- return copy;
- }
-
- public void setCenter(double[] center) {
- this.center = center;
- }
-
- public double getRadius() {
- return radius;
- }
-
- public void setRadius( double radius ) {
- this.radius = radius;
- }
-
- @Override
- public double getWeight() {
- return weight;
- }
-
- public void setWeight( double weight ) {
- this.weight = weight;
- }
-
- @Override
- public double getInclusionProbability(Instance instance) {
- if (getCenterDistance(instance) <= getRadius()) {
- return 1.0;
- }
- return 0.0;
- }
-
- public double getCenterDistance(Instance instance) {
- double distance = 0.0;
- //get the center through getCenter so subclass have a chance
- double[] center = getCenter();
- for (int i = 0; i < center.length; i++) {
- double d = center[i] - instance.value(i);
- distance += d * d;
- }
- return Math.sqrt(distance);
- }
-
- public double getCenterDistance(SphereCluster other) {
- return distance(getCenter(), other.getCenter());
- }
-
- /*
- * the minimal distance between the surface of two clusters.
- * is negative if the two clusters overlap
- */
- public double getHullDistance(SphereCluster other) {
- double distance;
- //get the center through getCenter so subclass have a chance
- double[] center0 = getCenter();
- double[] center1 = other.getCenter();
- distance = distance(center0, center1);
-
- distance = distance - getRadius() - other.getRadius();
- return distance;
- }
-
- /*
- */
- /**
- * When a clusters looses points the new minimal bounding sphere can be
- * partly outside of the originating cluster. If a another cluster is
- * right next to the original cluster (without overlapping), the new
- * cluster can be overlapping with this second cluster. OverlapSave
- * will tell you if the current cluster can degenerate so much that it
- * overlaps with cluster 'other'
- *
- * @param other the potentially overlapping cluster
- * @return true if cluster can potentially overlap
- */
- public boolean overlapSave(SphereCluster other){
- //use basic geometry to figure out the maximal degenerated cluster
- //comes down to Max(radius *(sin alpha + cos alpha)) which is
- double minDist = Math.sqrt(2)*(getRadius() + other.getRadius());
- double diff = getCenterDistance(other) - minDist;
-
- return diff > 0;
- }
-
- private double distance(double[] v1, double[] v2){
- double distance = 0.0;
- double[] center = getCenter();
- for (int i = 0; i < center.length; i++) {
- double d = v1[i] - v2[i];
- distance += d * d;
- }
- return Math.sqrt(distance);
- }
-
- public double[] getDistanceVector(Instance instance){
- return distanceVector(getCenter(), instance.toDoubleArray());
- }
-
- public double[] getDistanceVector(SphereCluster other){
- return distanceVector(getCenter(), other.getCenter());
- }
-
- private double[] distanceVector(double[] v1, double[] v2){
- double[] v = new double[v1.length];
- for (int i = 0; i < v1.length; i++) {
- v[i] = v2[i] - v1[i];
- }
- return v;
- }
-
-
- /**
- * Samples this cluster by returning a point from inside it.
- * @param random a random number source
- * @return a point that lies inside this cluster
- */
- public Instance sample(Random random) {
- // Create sample in hypersphere coordinates
- //get the center through getCenter so subclass have a chance
- double[] center = getCenter();
-
- final int dimensions = center.length;
-
- final double sin[] = new double[dimensions - 1];
- final double cos[] = new double[dimensions - 1];
- final double length = random.nextDouble() * getRadius();
-
- double lastValue = 1.0;
- for (int i = 0; i < dimensions-1; i++) {
- double angle = random.nextDouble() * 2 * Math.PI;
- sin[i] = lastValue * Math.sin( angle ); // Store cumulative values
- cos[i] = Math.cos( angle );
- lastValue = sin[i];
- }
-
- // Calculate cartesian coordinates
- double res[] = new double[dimensions];
-
- // First value uses only cosines
- res[0] = center[0] + length*cos[0];
-
- // Loop through 'middle' coordinates which use cosines and sines
- for (int i = 1; i < dimensions-1; i++) {
- res[i] = center[i] + length*sin[i-1]*cos[i];
- }
-
- // Last value uses only sines
- res[dimensions-1] = center[dimensions-1] + length*sin[dimensions-2];
-
- return new DenseInstance(1.0, res);
- }
-
- @Override
- protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue) {
- super.getClusterSpecificInfo(infoTitle, infoValue);
- infoTitle.add("Radius");
- infoValue.add(Double.toString(getRadius()));
- }
-
+ @Override
+ protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue) {
+ super.getClusterSpecificInfo(infoTitle, infoValue);
+ infoTitle.add("Radius");
+ infoValue.add(Double.toString(getRadius()));
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/AbstractClusterer.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/AbstractClusterer.java
index a1d80d1..b023648 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/AbstractClusterer.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/AbstractClusterer.java
@@ -37,262 +37,262 @@
import com.yahoo.labs.samoa.instances.Instances;
public abstract class AbstractClusterer extends AbstractOptionHandler
- implements Clusterer {
-
- @Override
- public String getPurposeString() {
- return "MOA Clusterer: " + getClass().getCanonicalName();
- }
+ implements Clusterer {
- protected InstancesHeader modelContext;
+ @Override
+ public String getPurposeString() {
+ return "MOA Clusterer: " + getClass().getCanonicalName();
+ }
- protected double trainingWeightSeenByModel = 0.0;
+ protected InstancesHeader modelContext;
- protected int randomSeed = 1;
+ protected double trainingWeightSeenByModel = 0.0;
- protected IntOption randomSeedOption;
+ protected int randomSeed = 1;
- public FlagOption evaluateMicroClusteringOption;
+ protected IntOption randomSeedOption;
- protected Random clustererRandom;
+ public FlagOption evaluateMicroClusteringOption;
- protected Clustering clustering;
-
- public AbstractClusterer() {
- if (isRandomizable()) {
- this.randomSeedOption = new IntOption("randomSeed", 'r',
- "Seed for random behaviour of the Clusterer.", 1);
- }
+ protected Random clustererRandom;
- if( implementsMicroClusterer()){
- this.evaluateMicroClusteringOption =
- new FlagOption("evaluateMicroClustering", 'M',
- "Evaluate the underlying microclustering instead of the macro clustering");
+ protected Clustering clustering;
+
+ public AbstractClusterer() {
+ if (isRandomizable()) {
+ this.randomSeedOption = new IntOption("randomSeed", 'r',
+ "Seed for random behaviour of the Clusterer.", 1);
+ }
+
+ if (implementsMicroClusterer()) {
+ this.evaluateMicroClusteringOption =
+ new FlagOption("evaluateMicroClustering", 'M',
+ "Evaluate the underlying microclustering instead of the macro clustering");
+ }
+ }
+
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ if (this.randomSeedOption != null) {
+ this.randomSeed = this.randomSeedOption.getValue();
+ }
+ if (!trainingHasStarted()) {
+ resetLearning();
+ }
+ clustering = new Clustering();
+ }
+
+ public void setModelContext(InstancesHeader ih) {
+ if ((ih != null) && (ih.classIndex() < 0)) {
+ throw new IllegalArgumentException(
+ "Context for a Clusterer must include a class to learn");
+ }
+ if (trainingHasStarted()
+ && (this.modelContext != null)
+ && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) {
+ throw new IllegalArgumentException(
+ "New context is not compatible with existing model");
+ }
+ this.modelContext = ih;
+ }
+
+ public InstancesHeader getModelContext() {
+ return this.modelContext;
+ }
+
+ public void setRandomSeed(int s) {
+ this.randomSeed = s;
+ if (this.randomSeedOption != null) {
+ // keep option consistent
+ this.randomSeedOption.setValue(s);
+ }
+ }
+
+ public boolean trainingHasStarted() {
+ return this.trainingWeightSeenByModel > 0.0;
+ }
+
+ public double trainingWeightSeenByModel() {
+ return this.trainingWeightSeenByModel;
+ }
+
+ public void resetLearning() {
+ this.trainingWeightSeenByModel = 0.0;
+ if (isRandomizable()) {
+ this.clustererRandom = new Random(this.randomSeed);
+ }
+ resetLearningImpl();
+ }
+
+ public void trainOnInstance(Instance inst) {
+ if (inst.weight() > 0.0) {
+ this.trainingWeightSeenByModel += inst.weight();
+ trainOnInstanceImpl(inst);
+ }
+ }
+
+ public Measurement[] getModelMeasurements() {
+ List<Measurement> measurementList = new LinkedList<Measurement>();
+ measurementList.add(new Measurement("model training instances",
+ trainingWeightSeenByModel()));
+ measurementList.add(new Measurement("model serialized size (bytes)",
+ measureByteSize()));
+ Measurement[] modelMeasurements = getModelMeasurementsImpl();
+ if (modelMeasurements != null) {
+ for (Measurement measurement : modelMeasurements) {
+ measurementList.add(measurement);
+ }
+ }
+ // add average of sub-model measurements
+ Clusterer[] subModels = getSubClusterers();
+ if ((subModels != null) && (subModels.length > 0)) {
+ List<Measurement[]> subMeasurements = new LinkedList<Measurement[]>();
+ for (Clusterer subModel : subModels) {
+ if (subModel != null) {
+ subMeasurements.add(subModel.getModelMeasurements());
}
- }
+ }
+ Measurement[] avgMeasurements = Measurement
+ .averageMeasurements(subMeasurements
+ .toArray(new Measurement[subMeasurements.size()][]));
+ for (Measurement measurement : avgMeasurements) {
+ measurementList.add(measurement);
+ }
+ }
+ return measurementList.toArray(new Measurement[measurementList.size()]);
+ }
- @Override
- public void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- if (this.randomSeedOption != null) {
- this.randomSeed = this.randomSeedOption.getValue();
- }
- if (!trainingHasStarted()) {
- resetLearning();
- }
- clustering = new Clustering();
- }
+ public void getDescription(StringBuilder out, int indent) {
+ StringUtils.appendIndented(out, indent, "Model type: ");
+ out.append(this.getClass().getName());
+ StringUtils.appendNewline(out);
+ Measurement.getMeasurementsDescription(getModelMeasurements(), out,
+ indent);
+ StringUtils.appendNewlineIndented(out, indent, "Model description:");
+ StringUtils.appendNewline(out);
+ if (trainingHasStarted()) {
+ getModelDescription(out, indent);
+ } else {
+ StringUtils.appendIndented(out, indent,
+ "Model has not been trained.");
+ }
+ }
- public void setModelContext(InstancesHeader ih) {
- if ((ih != null) && (ih.classIndex() < 0)) {
- throw new IllegalArgumentException(
- "Context for a Clusterer must include a class to learn");
- }
- if (trainingHasStarted()
- && (this.modelContext != null)
- && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) {
- throw new IllegalArgumentException(
- "New context is not compatible with existing model");
- }
- this.modelContext = ih;
- }
+ public Clusterer[] getSubClusterers() {
+ return null;
+ }
- public InstancesHeader getModelContext() {
- return this.modelContext;
- }
+ @Override
+ public Clusterer copy() {
+ return (Clusterer) super.copy();
+ }
- public void setRandomSeed(int s) {
- this.randomSeed = s;
- if (this.randomSeedOption != null) {
- // keep option consistent
- this.randomSeedOption.setValue(s);
- }
- }
+ // public boolean correctlyClassifies(Instance inst) {
+ // return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst
+ // .classValue();
+ // }
- public boolean trainingHasStarted() {
- return this.trainingWeightSeenByModel > 0.0;
- }
+ public String getClassNameString() {
+ return InstancesHeader.getClassNameString(this.modelContext);
+ }
- public double trainingWeightSeenByModel() {
- return this.trainingWeightSeenByModel;
- }
+ public String getClassLabelString(int classLabelIndex) {
+ return InstancesHeader.getClassLabelString(this.modelContext,
+ classLabelIndex);
+ }
- public void resetLearning() {
- this.trainingWeightSeenByModel = 0.0;
- if (isRandomizable()) {
- this.clustererRandom = new Random(this.randomSeed);
- }
- resetLearningImpl();
- }
+ public String getAttributeNameString(int attIndex) {
+ return InstancesHeader.getAttributeNameString(this.modelContext,
+ attIndex);
+ }
- public void trainOnInstance(Instance inst) {
- if (inst.weight() > 0.0) {
- this.trainingWeightSeenByModel += inst.weight();
- trainOnInstanceImpl(inst);
- }
- }
+ public String getNominalValueString(int attIndex, int valIndex) {
+ return InstancesHeader.getNominalValueString(this.modelContext,
+ attIndex, valIndex);
+ }
- public Measurement[] getModelMeasurements() {
- List<Measurement> measurementList = new LinkedList<Measurement>();
- measurementList.add(new Measurement("model training instances",
- trainingWeightSeenByModel()));
- measurementList.add(new Measurement("model serialized size (bytes)",
- measureByteSize()));
- Measurement[] modelMeasurements = getModelMeasurementsImpl();
- if (modelMeasurements != null) {
- for (Measurement measurement : modelMeasurements) {
- measurementList.add(measurement);
- }
- }
- // add average of sub-model measurements
- Clusterer[] subModels = getSubClusterers();
- if ((subModels != null) && (subModels.length > 0)) {
- List<Measurement[]> subMeasurements = new LinkedList<Measurement[]>();
- for (Clusterer subModel : subModels) {
- if (subModel != null) {
- subMeasurements.add(subModel.getModelMeasurements());
- }
- }
- Measurement[] avgMeasurements = Measurement
- .averageMeasurements(subMeasurements
- .toArray(new Measurement[subMeasurements.size()][]));
- for (Measurement measurement : avgMeasurements) {
- measurementList.add(measurement);
- }
- }
- return measurementList.toArray(new Measurement[measurementList.size()]);
- }
-
- public void getDescription(StringBuilder out, int indent) {
- StringUtils.appendIndented(out, indent, "Model type: ");
- out.append(this.getClass().getName());
- StringUtils.appendNewline(out);
- Measurement.getMeasurementsDescription(getModelMeasurements(), out,
- indent);
- StringUtils.appendNewlineIndented(out, indent, "Model description:");
- StringUtils.appendNewline(out);
- if (trainingHasStarted()) {
- getModelDescription(out, indent);
- } else {
- StringUtils.appendIndented(out, indent,
- "Model has not been trained.");
- }
- }
-
- public Clusterer[] getSubClusterers() {
- return null;
- }
-
- @Override
- public Clusterer copy() {
- return (Clusterer) super.copy();
- }
-
-// public boolean correctlyClassifies(Instance inst) {
-// return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst
-// .classValue();
-// }
-
- public String getClassNameString() {
- return InstancesHeader.getClassNameString(this.modelContext);
- }
-
- public String getClassLabelString(int classLabelIndex) {
- return InstancesHeader.getClassLabelString(this.modelContext,
- classLabelIndex);
- }
-
- public String getAttributeNameString(int attIndex) {
- return InstancesHeader.getAttributeNameString(this.modelContext,
- attIndex);
- }
-
- public String getNominalValueString(int attIndex, int valIndex) {
- return InstancesHeader.getNominalValueString(this.modelContext,
- attIndex, valIndex);
- }
-
- // originalContext notnull
- // newContext notnull
- public static boolean contextIsCompatible(InstancesHeader originalContext,
- InstancesHeader newContext) {
- // rule 1: num classes can increase but never decrease
- // rule 2: num attributes can increase but never decrease
- // rule 3: num nominal attribute values can increase but never decrease
- // rule 4: attribute types must stay in the same order (although class
- // can
- // move; is always skipped over)
- // attribute names are free to change, but should always still represent
- // the original attributes
- if (newContext.numClasses() < originalContext.numClasses()) {
- return false; // rule 1
- }
- if (newContext.numAttributes() < originalContext.numAttributes()) {
- return false; // rule 2
- }
- int oPos = 0;
- int nPos = 0;
- while (oPos < originalContext.numAttributes()) {
- if (oPos == originalContext.classIndex()) {
- oPos++;
- if (!(oPos < originalContext.numAttributes())) {
- break;
- }
- }
- if (nPos == newContext.classIndex()) {
- nPos++;
- }
- if (originalContext.attribute(oPos).isNominal()) {
- if (!newContext.attribute(nPos).isNominal()) {
- return false; // rule 4
- }
- if (newContext.attribute(nPos).numValues() < originalContext
- .attribute(oPos).numValues()) {
- return false; // rule 3
- }
- } else {
- assert (originalContext.attribute(oPos).isNumeric());
- if (!newContext.attribute(nPos).isNumeric()) {
- return false; // rule 4
- }
- }
- oPos++;
- nPos++;
- }
- return true; // all checks clear
- }
-
- // reason for ...Impl methods:
- // ease programmer burden by not requiring them to remember calls to super
- // in overridden methods & will produce compiler errors if not overridden
-
- public abstract void resetLearningImpl();
-
- public abstract void trainOnInstanceImpl(Instance inst);
-
- protected abstract Measurement[] getModelMeasurementsImpl();
-
- public abstract void getModelDescription(StringBuilder out, int indent);
-
- protected static int modelAttIndexToInstanceAttIndex(int index,
- Instance inst) {
- return inst.classIndex() > index ? index : index + 1;
- }
-
- protected static int modelAttIndexToInstanceAttIndex(int index,
- Instances insts) {
- return insts.classIndex() > index ? index : index + 1;
- }
-
- public boolean implementsMicroClusterer(){
- return false;
+ // originalContext notnull
+ // newContext notnull
+ public static boolean contextIsCompatible(InstancesHeader originalContext,
+ InstancesHeader newContext) {
+ // rule 1: num classes can increase but never decrease
+ // rule 2: num attributes can increase but never decrease
+ // rule 3: num nominal attribute values can increase but never decrease
+ // rule 4: attribute types must stay in the same order (although class
+ // can
+ // move; is always skipped over)
+ // attribute names are free to change, but should always still represent
+ // the original attributes
+ if (newContext.numClasses() < originalContext.numClasses()) {
+ return false; // rule 1
+ }
+ if (newContext.numAttributes() < originalContext.numAttributes()) {
+ return false; // rule 2
+ }
+ int oPos = 0;
+ int nPos = 0;
+ while (oPos < originalContext.numAttributes()) {
+ if (oPos == originalContext.classIndex()) {
+ oPos++;
+ if (!(oPos < originalContext.numAttributes())) {
+ break;
}
-
- public boolean keepClassLabel(){
- return false;
+ }
+ if (nPos == newContext.classIndex()) {
+ nPos++;
+ }
+ if (originalContext.attribute(oPos).isNominal()) {
+ if (!newContext.attribute(nPos).isNominal()) {
+ return false; // rule 4
}
-
- public Clustering getMicroClusteringResult(){
- return null;
- };
+ if (newContext.attribute(nPos).numValues() < originalContext
+ .attribute(oPos).numValues()) {
+ return false; // rule 3
+ }
+ } else {
+ assert (originalContext.attribute(oPos).isNumeric());
+ if (!newContext.attribute(nPos).isNumeric()) {
+ return false; // rule 4
+ }
+ }
+ oPos++;
+ nPos++;
+ }
+ return true; // all checks clear
+ }
+
+ // reason for ...Impl methods:
+ // ease programmer burden by not requiring them to remember calls to super
+ // in overridden methods & will produce compiler errors if not overridden
+
+ public abstract void resetLearningImpl();
+
+ public abstract void trainOnInstanceImpl(Instance inst);
+
+ protected abstract Measurement[] getModelMeasurementsImpl();
+
+ public abstract void getModelDescription(StringBuilder out, int indent);
+
+ protected static int modelAttIndexToInstanceAttIndex(int index,
+ Instance inst) {
+ return inst.classIndex() > index ? index : index + 1;
+ }
+
+ protected static int modelAttIndexToInstanceAttIndex(int index,
+ Instances insts) {
+ return insts.classIndex() > index ? index : index + 1;
+ }
+
+ public boolean implementsMicroClusterer() {
+ return false;
+ }
+
+ public boolean keepClassLabel() {
+ return false;
+ }
+
+ public Clustering getMicroClusteringResult() {
+ return null;
+ };
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/ClusterGenerator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/ClusterGenerator.java
index 6056514..6e01f86 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/ClusterGenerator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/ClusterGenerator.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.clusterers;
/*
@@ -32,337 +31,326 @@
import com.yahoo.labs.samoa.moa.core.DataPoint;
import com.yahoo.labs.samoa.instances.Instance;
-public class ClusterGenerator extends AbstractClusterer{
+public class ClusterGenerator extends AbstractClusterer {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public IntOption timeWindowOption = new IntOption("timeWindow",
- 't', "Rang of the window.", 1000);
+ public IntOption timeWindowOption = new IntOption("timeWindow",
+ 't', "Rang of the window.", 1000);
- public FloatOption radiusDecreaseOption = new FloatOption("radiusDecrease", 'r',
- "The average radii of the centroids in the model.", 0, 0, 1);
+ public FloatOption radiusDecreaseOption = new FloatOption("radiusDecrease", 'r',
+ "The average radii of the centroids in the model.", 0, 0, 1);
- public FloatOption radiusIncreaseOption = new FloatOption("radiusIncrease", 'R',
- "The average radii of the centroids in the model.", 0, 0, 1);
+ public FloatOption radiusIncreaseOption = new FloatOption("radiusIncrease", 'R',
+ "The average radii of the centroids in the model.", 0, 0, 1);
- public FloatOption positionOffsetOption = new FloatOption("positionOffset", 'p',
- "The average radii of the centroids in the model.", 0, 0, 1);
+ public FloatOption positionOffsetOption = new FloatOption("positionOffset", 'p',
+ "The average radii of the centroids in the model.", 0, 0, 1);
- public FloatOption clusterRemoveOption = new FloatOption("clusterRemove", 'D',
- "Deletes complete clusters from the clustering.", 0, 0, 1);
+ public FloatOption clusterRemoveOption = new FloatOption("clusterRemove", 'D',
+ "Deletes complete clusters from the clustering.", 0, 0, 1);
- public FloatOption joinClustersOption = new FloatOption("joinClusters", 'j',
- "Join two clusters if their hull distance is less minRadius times this factor.", 0, 0, 1);
+ public FloatOption joinClustersOption = new FloatOption("joinClusters", 'j',
+ "Join two clusters if their hull distance is less minRadius times this factor.", 0, 0, 1);
- public FloatOption clusterAddOption = new FloatOption("clusterAdd", 'A',
- "Adds additional clusters.", 0, 0, 1);
+ public FloatOption clusterAddOption = new FloatOption("clusterAdd", 'A',
+ "Adds additional clusters.", 0, 0, 1);
- private static double err_intervall_width = 0.0;
- private ArrayList<DataPoint> points;
- private int instanceCounter;
- private int windowCounter;
- private Random random;
- private Clustering sourceClustering = null;
+ private static double err_intervall_width = 0.0;
+ private ArrayList<DataPoint> points;
+ private int instanceCounter;
+ private int windowCounter;
+ private Random random;
+ private Clustering sourceClustering = null;
- @Override
- public void resetLearningImpl() {
- points = new ArrayList<DataPoint>();
- instanceCounter = 0;
- windowCounter = 0;
- random = new Random(227);
+ @Override
+ public void resetLearningImpl() {
+ points = new ArrayList<DataPoint>();
+ instanceCounter = 0;
+ windowCounter = 0;
+ random = new Random(227);
- //joinClustersOption.set();
- //evaluateMicroClusteringOption.set();
+ // joinClustersOption.set();
+ // evaluateMicroClusteringOption.set();
+ }
+
+ @Override
+ public void trainOnInstanceImpl(Instance inst) {
+ if (windowCounter >= timeWindowOption.getValue()) {
+ points.clear();
+ windowCounter = 0;
+ }
+ windowCounter++;
+ instanceCounter++;
+ points.add(new DataPoint(inst, instanceCounter));
+ }
+
+ @Override
+ public boolean implementsMicroClusterer() {
+ return true;
+ }
+
+ public void setSourceClustering(Clustering source) {
+ sourceClustering = source;
+ }
+
+ @Override
+ public Clustering getMicroClusteringResult() {
+ // System.out.println("Numcluster:"+clustering.size()+" / "+num);
+ // Clustering source_clustering = new Clustering(points, overlapThreshold,
+ // microInitMinPoints);
+ if (sourceClustering == null) {
+
+ System.out.println("You need to set a source clustering for the ClusterGenerator to work");
+ return null;
+ }
+ return alterClustering(sourceClustering);
+ }
+
+ public Clustering getClusteringResult() {
+ sourceClustering = new Clustering(points);
+ // if(sourceClustering == null){
+ // System.out.println("You need to set a source clustering for the ClusterGenerator to work");
+ // return null;
+ // }
+ return alterClustering(sourceClustering);
+ }
+
+ private Clustering alterClustering(Clustering scclustering) {
+ // percentage of the radius that will be cut off
+ // 0: no changes to radius
+ // 1: radius of 0
+ double errLevelRadiusDecrease = radiusDecreaseOption.getValue();
+
+ // 0: no changes to radius
+ // 1: radius 100% bigger
+ double errLevelRadiusIncrease = radiusIncreaseOption.getValue();
+
+ // 0: no changes
+ // 1: distance between centers is 2 * original radius
+ double errLevelPosition = positionOffsetOption.getValue();
+
+ int numRemoveCluster = (int) (clusterRemoveOption.getValue() * scclustering.size());
+
+ int numAddCluster = (int) (clusterAddOption.getValue() * scclustering.size());
+
+ for (int c = 0; c < numRemoveCluster; c++) {
+ int delId = random.nextInt(scclustering.size());
+ scclustering.remove(delId);
}
- @Override
- public void trainOnInstanceImpl(Instance inst) {
- if(windowCounter >= timeWindowOption.getValue()){
- points.clear();
- windowCounter = 0;
+ int numCluster = scclustering.size();
+ double[] err_seeds = new double[numCluster];
+ double err_seed_sum = 0.0;
+ double tmp_seed;
+ for (int i = 0; i < numCluster; i++) {
+ tmp_seed = random.nextDouble();
+ err_seeds[i] = err_seed_sum + tmp_seed;
+ err_seed_sum += tmp_seed;
+ }
+
+ double sumWeight = 0;
+ for (int i = 0; i < numCluster; i++) {
+ sumWeight += scclustering.get(i).getWeight();
+ }
+
+ Clustering clustering = new Clustering();
+
+ for (int i = 0; i < numCluster; i++) {
+ if (!(scclustering.get(i) instanceof SphereCluster)) {
+ System.out.println("Not a Sphere Cluster");
+ continue;
+ }
+ SphereCluster sourceCluster = (SphereCluster) scclustering.get(i);
+ double[] center = Arrays.copyOf(sourceCluster.getCenter(), sourceCluster.getCenter().length);
+ double weight = sourceCluster.getWeight();
+ double radius = sourceCluster.getRadius();
+
+ // move cluster center
+ if (errLevelPosition > 0) {
+ double errOffset = random.nextDouble() * err_intervall_width / 2.0;
+ double errOffsetDirection = ((random.nextBoolean()) ? 1 : -1);
+ double level = errLevelPosition + errOffsetDirection * errOffset;
+ double[] vector = new double[center.length];
+ double vectorLength = 0;
+ for (int d = 0; d < center.length; d++) {
+ vector[d] = (random.nextBoolean() ? 1 : -1) * random.nextDouble();
+ vectorLength += Math.pow(vector[d], 2);
}
- windowCounter++;
- instanceCounter++;
- points.add( new DataPoint(inst,instanceCounter));
- }
+ vectorLength = Math.sqrt(vectorLength);
- @Override
- public boolean implementsMicroClusterer() {
- return true;
- }
+ // max is when clusters are next to each other
+ double length = 2 * radius * level;
-
- public void setSourceClustering(Clustering source){
- sourceClustering = source;
- }
-
- @Override
- public Clustering getMicroClusteringResult() {
- //System.out.println("Numcluster:"+clustering.size()+" / "+num);
- //Clustering source_clustering = new Clustering(points, overlapThreshold, microInitMinPoints);
- if(sourceClustering == null){
-
- System.out.println("You need to set a source clustering for the ClusterGenerator to work");
- return null;
+ for (int d = 0; d < center.length; d++) {
+          // normalize length and then stretch to reach error position
+ vector[d] = vector[d] / vectorLength * length;
}
- return alterClustering(sourceClustering);
- }
-
-
-
- public Clustering getClusteringResult(){
- sourceClustering = new Clustering(points);
-// if(sourceClustering == null){
-// System.out.println("You need to set a source clustering for the ClusterGenerator to work");
-// return null;
-// }
- return alterClustering(sourceClustering);
- }
-
-
- private Clustering alterClustering(Clustering scclustering){
- //percentage of the radius that will be cut off
- //0: no changes to radius
- //1: radius of 0
- double errLevelRadiusDecrease = radiusDecreaseOption.getValue();
-
- //0: no changes to radius
- //1: radius 100% bigger
- double errLevelRadiusIncrease = radiusIncreaseOption.getValue();
-
- //0: no changes
- //1: distance between centers is 2 * original radius
- double errLevelPosition = positionOffsetOption.getValue();
-
-
- int numRemoveCluster = (int)(clusterRemoveOption.getValue()*scclustering.size());
-
- int numAddCluster = (int)(clusterAddOption.getValue()*scclustering.size());
-
- for (int c = 0; c < numRemoveCluster; c++) {
- int delId = random.nextInt(scclustering.size());
- scclustering.remove(delId);
+ // System.out.println("Center "+Arrays.toString(center));
+ // System.out.println("Vector "+Arrays.toString(vector));
+ // check if error position is within bounds
+ double[] newCenter = new double[center.length];
+ for (int d = 0; d < center.length; d++) {
+ // check bounds, otherwise flip vector
+ if (center[d] + vector[d] >= 0 && center[d] + vector[d] <= 1) {
+ newCenter[d] = center[d] + vector[d];
+ }
+ else {
+ newCenter[d] = center[d] + (-1) * vector[d];
+ }
}
-
- int numCluster = scclustering.size();
- double[] err_seeds = new double[numCluster];
- double err_seed_sum = 0.0;
- double tmp_seed;
- for (int i = 0; i < numCluster; i++) {
- tmp_seed = random.nextDouble();
- err_seeds[i] = err_seed_sum + tmp_seed;
- err_seed_sum+= tmp_seed;
+ center = newCenter;
+ for (int d = 0; d < center.length; d++) {
+ if (newCenter[d] >= 0 && newCenter[d] <= 1) {
+ }
+ else {
+ System.out
+ .println("This shouldnt have happend, Cluster center out of bounds:" + Arrays.toString(newCenter));
+ }
}
+ // System.out.println("new Center "+Arrays.toString(newCenter));
- double sumWeight = 0;
- for (int i = 0; i <numCluster; i++) {
- sumWeight+= scclustering.get(i).getWeight();
+ }
+
+ // alter radius
+ if (errLevelRadiusDecrease > 0 || errLevelRadiusIncrease > 0) {
+ double errOffset = random.nextDouble() * err_intervall_width / 2.0;
+ int errOffsetDirection = ((random.nextBoolean()) ? 1 : -1);
+
+ if (errLevelRadiusDecrease > 0 && (errLevelRadiusIncrease == 0 || random.nextBoolean())) {
+ double level = (errLevelRadiusDecrease + errOffsetDirection * errOffset);// *sourceCluster.getWeight()/sumWeight;
+ level = (level < 0) ? 0 : level;
+ level = (level > 1) ? 1 : level;
+ radius *= (1 - level);
}
-
- Clustering clustering = new Clustering();
-
- for (int i = 0; i <numCluster; i++) {
- if(!(scclustering.get(i) instanceof SphereCluster)){
- System.out.println("Not a Sphere Cluster");
- continue;
- }
- SphereCluster sourceCluster = (SphereCluster)scclustering.get(i);
- double[] center = Arrays.copyOf(sourceCluster.getCenter(),sourceCluster.getCenter().length);
- double weight = sourceCluster.getWeight();
- double radius = sourceCluster.getRadius();
-
- //move cluster center
- if(errLevelPosition >0){
- double errOffset = random.nextDouble()*err_intervall_width/2.0;
- double errOffsetDirection = ((random.nextBoolean())? 1 : -1);
- double level = errLevelPosition + errOffsetDirection * errOffset;
- double[] vector = new double[center.length];
- double vectorLength = 0;
- for (int d = 0; d < center.length; d++) {
- vector[d] = (random.nextBoolean()?1:-1)*random.nextDouble();
- vectorLength += Math.pow(vector[d],2);
- }
- vectorLength = Math.sqrt(vectorLength);
-
-
- //max is when clusters are next to each other
- double length = 2 * radius * level;
-
- for (int d = 0; d < center.length; d++) {
- //normalize length and then strecht to reach error position
- vector[d]=vector[d]/vectorLength*length;
- }
-// System.out.println("Center "+Arrays.toString(center));
-// System.out.println("Vector "+Arrays.toString(vector));
- //check if error position is within bounds
- double [] newCenter = new double[center.length];
- for (int d = 0; d < center.length; d++) {
- //check bounds, otherwise flip vector
- if(center[d] + vector[d] >= 0 && center[d] + vector[d] <= 1){
- newCenter[d] = center[d] + vector[d];
- }
- else{
- newCenter[d] = center[d] + (-1)*vector[d];
- }
- }
- center = newCenter;
- for (int d = 0; d < center.length; d++) {
- if(newCenter[d] >= 0 && newCenter[d] <= 1){
- }
- else{
- System.out.println("This shouldnt have happend, Cluster center out of bounds:"+Arrays.toString(newCenter));
- }
- }
- //System.out.println("new Center "+Arrays.toString(newCenter));
-
- }
-
- //alter radius
- if(errLevelRadiusDecrease > 0 || errLevelRadiusIncrease > 0){
- double errOffset = random.nextDouble()*err_intervall_width/2.0;
- int errOffsetDirection = ((random.nextBoolean())? 1 : -1);
-
- if(errLevelRadiusDecrease > 0 && (errLevelRadiusIncrease == 0 || random.nextBoolean())){
- double level = (errLevelRadiusDecrease + errOffsetDirection * errOffset);//*sourceCluster.getWeight()/sumWeight;
- level = (level<0)?0:level;
- level = (level>1)?1:level;
- radius*=(1-level);
- }
- else{
- double level = errLevelRadiusIncrease + errOffsetDirection * errOffset;
- level = (level<0)?0:level;
- level = (level>1)?1:level;
- radius+=radius*level;
- }
- }
-
- SphereCluster newCluster = new SphereCluster(center, radius, weight);
- newCluster.setMeasureValue("Source Cluster", "C"+sourceCluster.getId());
-
- clustering.add(newCluster);
+ else {
+ double level = errLevelRadiusIncrease + errOffsetDirection * errOffset;
+ level = (level < 0) ? 0 : level;
+ level = (level > 1) ? 1 : level;
+ radius += radius * level;
}
+ }
- if(joinClustersOption.getValue() > 0){
- clustering = joinClusters(clustering);
- }
+ SphereCluster newCluster = new SphereCluster(center, radius, weight);
+ newCluster.setMeasureValue("Source Cluster", "C" + sourceCluster.getId());
- //add new clusters by copying clusters and set a random center
- for (int c = 0; c < numAddCluster; c++) {
- int copyId = random.nextInt(clustering.size());
- SphereCluster scorg = (SphereCluster)clustering.get(copyId);
- int dim = scorg.getCenter().length;
- double[] center = new double [dim];
- double radius = scorg.getRadius();
-
- boolean outofbounds = true;
- int tryCounter = 0;
- while(outofbounds && tryCounter < 20){
- tryCounter++;
- outofbounds = false;
- for (int j = 0; j < center.length; j++) {
- center[j] = random.nextDouble();
- if(center[j]- radius < 0 || center[j] + radius > 1){
- outofbounds = true;
- break;
- }
- }
- }
- if(outofbounds){
- System.out.println("Coludn't place additional cluster");
- }
- else{
- SphereCluster scnew = new SphereCluster(center, radius, scorg.getWeight()/2);
- scorg.setWeight(scorg.getWeight()-scnew.getWeight());
- clustering.add(scnew);
- }
- }
-
- return clustering;
-
+ clustering.add(newCluster);
}
+ if (joinClustersOption.getValue() > 0) {
+ clustering = joinClusters(clustering);
+ }
+ // add new clusters by copying clusters and set a random center
+ for (int c = 0; c < numAddCluster; c++) {
+ int copyId = random.nextInt(clustering.size());
+ SphereCluster scorg = (SphereCluster) clustering.get(copyId);
+ int dim = scorg.getCenter().length;
+ double[] center = new double[dim];
+ double radius = scorg.getRadius();
- private Clustering joinClusters(Clustering clustering){
-
- double radiusFactor = joinClustersOption.getValue();
- boolean[] merged = new boolean[clustering.size()];
-
- Clustering mclustering = new Clustering();
-
- if(radiusFactor >0){
- for (int c1 = 0; c1 < clustering.size(); c1++) {
- SphereCluster sc1 = (SphereCluster) clustering.get(c1);
- double minDist = Double.MAX_VALUE;
- double minOver = 1;
- int maxindexCon = -1;
- int maxindexOver = -1;
- for (int c2 = 0; c2 < clustering.size(); c2++) {
- SphereCluster sc2 = (SphereCluster) clustering.get(c2);
-// double over = sc1.overlapRadiusDegree(sc2);
-// if(over > 0 && over < minOver){
-// minOver = over;
-// maxindexOver = c2;
-// }
- double dist = sc1.getHullDistance(sc2);
- double threshold = Math.min(sc1.getRadius(), sc2.getRadius())*radiusFactor;
- if(dist > 0 && dist < minDist && dist < threshold){
- minDist = dist;
- maxindexCon = c2;
- }
- }
- int maxindex = -1;
- if(maxindexOver!=-1)
- maxindex = maxindexOver;
- else
- maxindex = maxindexCon;
-
- if(maxindex!=-1 && !merged[c1]){
- merged[c1]=true;
- merged[maxindex]=true;
- SphereCluster scnew = new SphereCluster(sc1.getCenter(),sc1.getRadius(),sc1.getWeight());
- SphereCluster sc2 = (SphereCluster) clustering.get(maxindex);
- scnew.merge(sc2);
- mclustering.add(scnew);
- }
- }
+ boolean outofbounds = true;
+ int tryCounter = 0;
+ while (outofbounds && tryCounter < 20) {
+ tryCounter++;
+ outofbounds = false;
+ for (int j = 0; j < center.length; j++) {
+ center[j] = random.nextDouble();
+ if (center[j] - radius < 0 || center[j] + radius > 1) {
+ outofbounds = true;
+ break;
+ }
}
+ }
+ if (outofbounds) {
+ System.out.println("Coludn't place additional cluster");
+ }
+ else {
+ SphereCluster scnew = new SphereCluster(center, radius, scorg.getWeight() / 2);
+ scorg.setWeight(scorg.getWeight() - scnew.getWeight());
+ clustering.add(scnew);
+ }
+ }
- for (int i = 0; i < merged.length; i++) {
- if(!merged[i])
- mclustering.add(clustering.get(i));
+ return clustering;
+
+ }
+
+ private Clustering joinClusters(Clustering clustering) {
+
+ double radiusFactor = joinClustersOption.getValue();
+ boolean[] merged = new boolean[clustering.size()];
+
+ Clustering mclustering = new Clustering();
+
+ if (radiusFactor > 0) {
+ for (int c1 = 0; c1 < clustering.size(); c1++) {
+ SphereCluster sc1 = (SphereCluster) clustering.get(c1);
+ double minDist = Double.MAX_VALUE;
+ double minOver = 1;
+ int maxindexCon = -1;
+ int maxindexOver = -1;
+ for (int c2 = 0; c2 < clustering.size(); c2++) {
+ SphereCluster sc2 = (SphereCluster) clustering.get(c2);
+ // double over = sc1.overlapRadiusDegree(sc2);
+ // if(over > 0 && over < minOver){
+ // minOver = over;
+ // maxindexOver = c2;
+ // }
+ double dist = sc1.getHullDistance(sc2);
+ double threshold = Math.min(sc1.getRadius(), sc2.getRadius()) * radiusFactor;
+ if (dist > 0 && dist < minDist && dist < threshold) {
+ minDist = dist;
+ maxindexCon = c2;
+ }
}
+ int maxindex = -1;
+ if (maxindexOver != -1)
+ maxindex = maxindexOver;
+ else
+ maxindex = maxindexCon;
-
- return mclustering;
-
+ if (maxindex != -1 && !merged[c1]) {
+ merged[c1] = true;
+ merged[maxindex] = true;
+ SphereCluster scnew = new SphereCluster(sc1.getCenter(), sc1.getRadius(), sc1.getWeight());
+ SphereCluster sc2 = (SphereCluster) clustering.get(maxindex);
+ scnew.merge(sc2);
+ mclustering.add(scnew);
+ }
+ }
}
-
-
- @Override
- protected Measurement[] getModelMeasurementsImpl() {
- throw new UnsupportedOperationException("Not supported yet.");
+ for (int i = 0; i < merged.length; i++) {
+ if (!merged[i])
+ mclustering.add(clustering.get(i));
}
- @Override
- public void getModelDescription(StringBuilder out, int indent) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
+ return mclustering;
- @Override
- public boolean isRandomizable() {
- return false;
- }
+ }
- @Override
- public boolean keepClassLabel(){
- return true;
- }
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
- public double[] getVotesForInstance(Instance inst) {
- return null;
- }
+ @Override
+ public void getModelDescription(StringBuilder out, int indent) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
+
+ @Override
+ public boolean isRandomizable() {
+ return false;
+ }
+
+ @Override
+ public boolean keepClassLabel() {
+ return true;
+ }
+
+ public double[] getVotesForInstance(Instance inst) {
+ return null;
+ }
}
-
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/Clusterer.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/Clusterer.java
index 76d7d50..480da5f 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/Clusterer.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/Clusterer.java
@@ -29,36 +29,36 @@
public interface Clusterer extends MOAObject, OptionHandler {
- public void setModelContext(InstancesHeader ih);
+ public void setModelContext(InstancesHeader ih);
- public InstancesHeader getModelContext();
+ public InstancesHeader getModelContext();
- public boolean isRandomizable();
+ public boolean isRandomizable();
- public void setRandomSeed(int s);
+ public void setRandomSeed(int s);
- public boolean trainingHasStarted();
+ public boolean trainingHasStarted();
- public double trainingWeightSeenByModel();
+ public double trainingWeightSeenByModel();
- public void resetLearning();
+ public void resetLearning();
- public void trainOnInstance(Instance inst);
+ public void trainOnInstance(Instance inst);
- public double[] getVotesForInstance(Instance inst);
+ public double[] getVotesForInstance(Instance inst);
- public Measurement[] getModelMeasurements();
+ public Measurement[] getModelMeasurements();
- public Clusterer[] getSubClusterers();
+ public Clusterer[] getSubClusterers();
- public Clusterer copy();
+ public Clusterer copy();
- public Clustering getClusteringResult();
+ public Clustering getClusteringResult();
- public boolean implementsMicroClusterer();
+ public boolean implementsMicroClusterer();
- public Clustering getMicroClusteringResult();
-
- public boolean keepClassLabel();
+ public Clustering getMicroClusteringResult();
+
+ public boolean keepClassLabel();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/KMeans.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/KMeans.java
index a9a891f..aa98307 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/KMeans.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/KMeans.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.clusterers;
/*
@@ -29,174 +28,173 @@
import com.yahoo.labs.samoa.moa.cluster.SphereCluster;
/**
- * A kMeans implementation for microclusterings. For now it only uses the real centers of the
- * groundtruthclustering for implementation. There should also be an option to use random
- * centers.
- * TODO: random centers
- * TODO: Create a macro clustering interface to make different macro clustering algorithms available
- * to micro clustering algorithms like clustream, denstream and clustree
- *
+ * A kMeans implementation for microclusterings. For now it only uses the real
+ * centers of the ground-truth clustering for implementation. There should also be
+ * an option to use random centers. TODO: random centers TODO: Create a macro
+ * clustering interface to make different macro clustering algorithms available
+ * to micro clustering algorithms like clustream, denstream and clustree
+ *
*/
public class KMeans {
- /**
- * This kMeans implementation clusters a big number of microclusters
- * into a smaller amount of macro clusters. To make it comparable to other
- * algorithms it uses the real centers of the ground truth macro clustering
- * to have the best possible initialization. The quality of resulting
- * macro clustering yields an upper bound for kMeans on the underlying
- * microclustering.
- *
- * @param centers of the ground truth clustering
- * @param data list of microclusters
- * @return
- */
- public static Clustering kMeans(Cluster[] centers, List<? extends Cluster> data ) {
- int k = centers.length;
+ /**
+   * This kMeans implementation clusters a large number of microclusters into a
+   * smaller number of macro clusters. To make it comparable to other algorithms
+ * it uses the real centers of the ground truth macro clustering to have the
+ * best possible initialization. The quality of resulting macro clustering
+ * yields an upper bound for kMeans on the underlying microclustering.
+ *
+ * @param centers
+ * of the ground truth clustering
+ * @param data
+ * list of microclusters
+ * @return
+ */
+ public static Clustering kMeans(Cluster[] centers, List<? extends Cluster> data) {
+ int k = centers.length;
- int dimensions = centers[0].getCenter().length;
+ int dimensions = centers[0].getCenter().length;
- ArrayList<ArrayList<Cluster>> clustering =
- new ArrayList<ArrayList<Cluster>>();
- for ( int i = 0; i < k; i++ ) {
- clustering.add( new ArrayList<Cluster>() );
- }
-
- int repetitions = 100;
- while ( repetitions-- >= 0 ) {
- // Assign points to clusters
- for ( Cluster point : data ) {
- double minDistance = distance( point.getCenter(), centers[0].getCenter() );
- int closestCluster = 0;
- for ( int i = 1; i < k; i++ ) {
- double distance = distance( point.getCenter(), centers[i].getCenter() );
- if ( distance < minDistance ) {
- closestCluster = i;
- minDistance = distance;
- }
- }
-
- clustering.get( closestCluster ).add( point );
- }
-
- // Calculate new centers and clear clustering lists
- SphereCluster[] newCenters = new SphereCluster[centers.length];
- for ( int i = 0; i < k; i++ ) {
- newCenters[i] = calculateCenter( clustering.get( i ), dimensions );
- clustering.get( i ).clear();
- }
- centers = newCenters;
- }
-
- return new Clustering( centers );
+ ArrayList<ArrayList<Cluster>> clustering =
+ new ArrayList<ArrayList<Cluster>>();
+ for (int i = 0; i < k; i++) {
+ clustering.add(new ArrayList<Cluster>());
}
- private static double distance(double[] pointA, double [] pointB){
- double distance = 0.0;
- for (int i = 0; i < pointA.length; i++) {
- double d = pointA[i] - pointB[i];
- distance += d * d;
- }
- return Math.sqrt(distance);
- }
-
-
- private static SphereCluster calculateCenter( ArrayList<Cluster> cluster, int dimensions ) {
- double[] res = new double[dimensions];
- for ( int i = 0; i < res.length; i++ ) {
- res[i] = 0.0;
- }
-
- if ( cluster.size() == 0 ) {
- return new SphereCluster( res, 0.0 );
- }
-
- for ( Cluster point : cluster ) {
- double [] center = point.getCenter();
- for (int i = 0; i < res.length; i++) {
- res[i] += center[i];
- }
- }
-
- // Normalize
- for ( int i = 0; i < res.length; i++ ) {
- res[i] /= cluster.size();
- }
-
- // Calculate radius
- double radius = 0.0;
- for ( Cluster point : cluster ) {
- double dist = distance( res, point.getCenter() );
- if ( dist > radius ) {
- radius = dist;
- }
- }
-
- return new SphereCluster( res, radius );
- }
-
- public static Clustering gaussianMeans(Clustering gtClustering, Clustering clustering) {
- ArrayList<CFCluster> microclusters = new ArrayList<CFCluster>();
- for (int i = 0; i < clustering.size(); i++) {
- if (clustering.get(i) instanceof CFCluster) {
- microclusters.add((CFCluster)clustering.get(i));
- }
- else{
- System.out.println("Unsupported Cluster Type:"+clustering.get(i).getClass()
- +". Cluster needs to extend moa.cluster.CFCluster");
- }
- }
- Cluster[] centers = new Cluster[gtClustering.size()];
- for (int i = 0; i < centers.length; i++) {
- centers[i] = gtClustering.get(i);
-
+ int repetitions = 100;
+ while (repetitions-- >= 0) {
+ // Assign points to clusters
+ for (Cluster point : data) {
+ double minDistance = distance(point.getCenter(), centers[0].getCenter());
+ int closestCluster = 0;
+ for (int i = 1; i < k; i++) {
+ double distance = distance(point.getCenter(), centers[i].getCenter());
+ if (distance < minDistance) {
+ closestCluster = i;
+ minDistance = distance;
+ }
}
- int k = centers.length;
- if ( microclusters.size() < k ) {
- return new Clustering( new Cluster[0]);
- }
+ clustering.get(closestCluster).add(point);
+ }
- Clustering kMeansResult = kMeans( centers, microclusters );
-
- k = kMeansResult.size();
- CFCluster[] res = new CFCluster[ k ];
-
- for ( CFCluster microcluster : microclusters) {
- // Find closest kMeans cluster
- double minDistance = Double.MAX_VALUE;
- int closestCluster = 0;
- for ( int i = 0; i < k; i++ ) {
- double distance = distance( kMeansResult.get(i).getCenter(), microcluster.getCenter() );
- if ( distance < minDistance ) {
- closestCluster = i;
- minDistance = distance;
- }
- }
-
- // Add to cluster
- if ( res[closestCluster] == null ) {
- res[closestCluster] = (CFCluster)microcluster.copy();
- } else {
- res[closestCluster].add(microcluster);
- }
- }
-
- // Clean up res
- int count = 0;
- for ( int i = 0; i < res.length; i++ ) {
- if ( res[i] != null )
- ++count;
- }
-
- CFCluster[] cleaned = new CFCluster[count];
- count = 0;
- for ( int i = 0; i < res.length; i++ ) {
- if ( res[i] != null )
- cleaned[count++] = res[i];
- }
-
- return new Clustering( cleaned );
+ // Calculate new centers and clear clustering lists
+ SphereCluster[] newCenters = new SphereCluster[centers.length];
+ for (int i = 0; i < k; i++) {
+ newCenters[i] = calculateCenter(clustering.get(i), dimensions);
+ clustering.get(i).clear();
+ }
+ centers = newCenters;
}
+ return new Clustering(centers);
+ }
+
+ private static double distance(double[] pointA, double[] pointB) {
+ double distance = 0.0;
+ for (int i = 0; i < pointA.length; i++) {
+ double d = pointA[i] - pointB[i];
+ distance += d * d;
+ }
+ return Math.sqrt(distance);
+ }
+
+ private static SphereCluster calculateCenter(ArrayList<Cluster> cluster, int dimensions) {
+ double[] res = new double[dimensions];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = 0.0;
+ }
+
+ if (cluster.size() == 0) {
+ return new SphereCluster(res, 0.0);
+ }
+
+ for (Cluster point : cluster) {
+ double[] center = point.getCenter();
+ for (int i = 0; i < res.length; i++) {
+ res[i] += center[i];
+ }
+ }
+
+ // Normalize
+ for (int i = 0; i < res.length; i++) {
+ res[i] /= cluster.size();
+ }
+
+ // Calculate radius
+ double radius = 0.0;
+ for (Cluster point : cluster) {
+ double dist = distance(res, point.getCenter());
+ if (dist > radius) {
+ radius = dist;
+ }
+ }
+
+ return new SphereCluster(res, radius);
+ }
+
+ public static Clustering gaussianMeans(Clustering gtClustering, Clustering clustering) {
+ ArrayList<CFCluster> microclusters = new ArrayList<CFCluster>();
+ for (int i = 0; i < clustering.size(); i++) {
+ if (clustering.get(i) instanceof CFCluster) {
+ microclusters.add((CFCluster) clustering.get(i));
+ }
+ else {
+ System.out.println("Unsupported Cluster Type:" + clustering.get(i).getClass()
+ + ". Cluster needs to extend moa.cluster.CFCluster");
+ }
+ }
+ Cluster[] centers = new Cluster[gtClustering.size()];
+ for (int i = 0; i < centers.length; i++) {
+ centers[i] = gtClustering.get(i);
+
+ }
+
+ int k = centers.length;
+ if (microclusters.size() < k) {
+ return new Clustering(new Cluster[0]);
+ }
+
+ Clustering kMeansResult = kMeans(centers, microclusters);
+
+ k = kMeansResult.size();
+ CFCluster[] res = new CFCluster[k];
+
+ for (CFCluster microcluster : microclusters) {
+ // Find closest kMeans cluster
+ double minDistance = Double.MAX_VALUE;
+ int closestCluster = 0;
+ for (int i = 0; i < k; i++) {
+ double distance = distance(kMeansResult.get(i).getCenter(), microcluster.getCenter());
+ if (distance < minDistance) {
+ closestCluster = i;
+ minDistance = distance;
+ }
+ }
+
+ // Add to cluster
+ if (res[closestCluster] == null) {
+ res[closestCluster] = (CFCluster) microcluster.copy();
+ } else {
+ res[closestCluster].add(microcluster);
+ }
+ }
+
+ // Clean up res
+ int count = 0;
+ for (int i = 0; i < res.length; i++) {
+ if (res[i] != null)
+ ++count;
+ }
+
+ CFCluster[] cleaned = new CFCluster[count];
+ count = 0;
+ for (int i = 0; i < res.length; i++) {
+ if (res[i] != null)
+ cleaned[count++] = res[i];
+ }
+
+ return new Clustering(cleaned);
+ }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/Clustream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/Clustream.java
index 975e61d..c329e5b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/Clustream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/Clustream.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.clusterers.clustream;
/*
@@ -34,304 +33,300 @@
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
-/** Citation: CluStream: Charu C. Aggarwal, Jiawei Han, Jianyong Wang, Philip S. Yu:
- * A Framework for Clustering Evolving Data Streams. VLDB 2003: 81-92
+/**
+ * Citation: CluStream: Charu C. Aggarwal, Jiawei Han, Jianyong Wang, Philip S.
+ * Yu: A Framework for Clustering Evolving Data Streams. VLDB 2003: 81-92
*/
-public class Clustream extends AbstractClusterer{
+public class Clustream extends AbstractClusterer {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public IntOption timeWindowOption = new IntOption("horizon",
- 'h', "Rang of the window.", 1000);
+ public IntOption timeWindowOption = new IntOption("horizon",
+ 'h', "Rang of the window.", 1000);
- public IntOption maxNumKernelsOption = new IntOption(
- "maxNumKernels", 'k',
- "Maximum number of micro kernels to use.", 100);
+ public IntOption maxNumKernelsOption = new IntOption(
+ "maxNumKernels", 'k',
+ "Maximum number of micro kernels to use.", 100);
- public IntOption kernelRadiFactorOption = new IntOption(
- "kernelRadiFactor", 't',
- "Multiplier for the kernel radius", 2);
+ public IntOption kernelRadiFactorOption = new IntOption(
+ "kernelRadiFactor", 't',
+ "Multiplier for the kernel radius", 2);
- private int timeWindow;
- private long timestamp = -1;
- private ClustreamKernel[] kernels;
- private boolean initialized;
- private List<ClustreamKernel> buffer; // Buffer for initialization with kNN
- private int bufferSize;
- private double t;
- private int m;
+ private int timeWindow;
+ private long timestamp = -1;
+ private ClustreamKernel[] kernels;
+ private boolean initialized;
+ private List<ClustreamKernel> buffer; // Buffer for initialization with kNN
+ private int bufferSize;
+ private double t;
+ private int m;
- public Clustream() {
- }
+ public Clustream() {
+ }
+ @Override
+ public void resetLearningImpl() {
+ this.kernels = new ClustreamKernel[maxNumKernelsOption.getValue()];
+ this.timeWindow = timeWindowOption.getValue();
+ this.initialized = false;
+ this.buffer = new LinkedList<>();
+ this.bufferSize = maxNumKernelsOption.getValue();
+ t = kernelRadiFactorOption.getValue();
+ m = maxNumKernelsOption.getValue();
+ }
- @Override
- public void resetLearningImpl() {
- this.kernels = new ClustreamKernel[maxNumKernelsOption.getValue()];
- this.timeWindow = timeWindowOption.getValue();
- this.initialized = false;
- this.buffer = new LinkedList<>();
- this.bufferSize = maxNumKernelsOption.getValue();
- t = kernelRadiFactorOption.getValue();
- m = maxNumKernelsOption.getValue();
- }
+ @Override
+ public void trainOnInstanceImpl(Instance instance) {
+ int dim = instance.numValues();
+ timestamp++;
+ // 0. Initialize
+ if (!initialized) {
+ if (buffer.size() < bufferSize) {
+ buffer.add(new ClustreamKernel(instance, dim, timestamp, t, m));
+ return;
+ }
- @Override
- public void trainOnInstanceImpl(Instance instance) {
- int dim = instance.numValues();
- timestamp++;
- // 0. Initialize
- if ( !initialized ) {
- if ( buffer.size() < bufferSize ) {
- buffer.add( new ClustreamKernel(instance,dim, timestamp, t, m) );
- return;
- }
+ int k = kernels.length;
+ // System.err.println("k="+k+" bufferSize="+bufferSize);
+ assert (k <= bufferSize);
- int k = kernels.length;
- //System.err.println("k="+k+" bufferSize="+bufferSize);
- assert (k <= bufferSize);
+ ClustreamKernel[] centers = new ClustreamKernel[k];
+ for (int i = 0; i < k; i++) {
+ centers[i] = buffer.get(i); // TODO: make random!
+ }
+ Clustering kmeans_clustering = kMeans(k, centers, buffer);
+ // Clustering kmeans_clustering = kMeans(k, buffer);
- ClustreamKernel[] centers = new ClustreamKernel[k];
- for ( int i = 0; i < k; i++ ) {
- centers[i] = buffer.get( i ); // TODO: make random!
- }
- Clustering kmeans_clustering = kMeans(k, centers, buffer);
-// Clustering kmeans_clustering = kMeans(k, buffer);
+ for (int i = 0; i < kmeans_clustering.size(); i++) {
+ kernels[i] = new ClustreamKernel(new DenseInstance(1.0, centers[i].getCenter()), dim, timestamp, t, m);
+ }
- for ( int i = 0; i < kmeans_clustering.size(); i++ ) {
- kernels[i] = new ClustreamKernel( new DenseInstance(1.0,centers[i].getCenter()), dim, timestamp, t, m );
- }
+ buffer.clear();
+ initialized = true;
+ return;
+ }
- buffer.clear();
- initialized = true;
- return;
- }
+ // 1. Determine closest kernel
+ ClustreamKernel closestKernel = null;
+ double minDistance = Double.MAX_VALUE;
+ for (ClustreamKernel kernel : kernels) {
+ // System.out.println(i+" "+kernels[i].getWeight()+" "+kernels[i].getDeviation());
+ double distance = distance(instance.toDoubleArray(), kernel.getCenter());
+ if (distance < minDistance) {
+ closestKernel = kernel;
+ minDistance = distance;
+ }
+ }
+ // 2. Check whether instance fits into closestKernel
+ double radius;
+ if (closestKernel != null && closestKernel.getWeight() == 1) {
+ // Special case: estimate radius by determining the distance to the
+ // next closest cluster
+ radius = Double.MAX_VALUE;
+ double[] center = closestKernel.getCenter();
+ for (ClustreamKernel kernel : kernels) {
+ if (kernel == closestKernel) {
+ continue;
+ }
- // 1. Determine closest kernel
- ClustreamKernel closestKernel = null;
- double minDistance = Double.MAX_VALUE;
- for (ClustreamKernel kernel : kernels) {
- //System.out.println(i+" "+kernels[i].getWeight()+" "+kernels[i].getDeviation());
- double distance = distance(instance.toDoubleArray(), kernel.getCenter());
- if (distance < minDistance) {
- closestKernel = kernel;
- minDistance = distance;
- }
- }
+ double distance = distance(kernel.getCenter(), center);
+ radius = Math.min(distance, radius);
+ }
+ } else {
+ radius = closestKernel.getRadius();
+ }
- // 2. Check whether instance fits into closestKernel
- double radius;
- if (closestKernel != null && closestKernel.getWeight() == 1) {
- // Special case: estimate radius by determining the distance to the
- // next closest cluster
- radius = Double.MAX_VALUE;
- double[] center = closestKernel.getCenter();
- for (ClustreamKernel kernel : kernels) {
- if (kernel == closestKernel) {
- continue;
- }
+ if (minDistance < radius) {
+ // Data fits, put into kernel and be happy
+ closestKernel.insert(instance, timestamp);
+ return;
+ }
- double distance = distance(kernel.getCenter(), center);
- radius = Math.min(distance, radius);
- }
- } else {
- radius = closestKernel.getRadius();
- }
+ // 3. Data does not fit, we need to free
+ // some space to insert a new kernel
+ long threshold = timestamp - timeWindow; // Kernels before this can be
+ // forgotten
- if ( minDistance < radius ) {
- // Date fits, put into kernel and be happy
- closestKernel.insert( instance, timestamp );
- return;
- }
+ // 3.1 Try to forget old kernels
+ for (int i = 0; i < kernels.length; i++) {
+ if (kernels[i].getRelevanceStamp() < threshold) {
+ kernels[i] = new ClustreamKernel(instance, dim, timestamp, t, m);
+ return;
+ }
+ }
- // 3. Date does not fit, we need to free
- // some space to insert a new kernel
- long threshold = timestamp - timeWindow; // Kernels before this can be forgotten
+ // 3.2 Merge closest two kernels
+ int closestA = 0;
+ int closestB = 0;
+ minDistance = Double.MAX_VALUE;
+ for (int i = 0; i < kernels.length; i++) {
+ double[] centerA = kernels[i].getCenter();
+ for (int j = i + 1; j < kernels.length; j++) {
+ double dist = distance(centerA, kernels[j].getCenter());
+ if (dist < minDistance) {
+ minDistance = dist;
+ closestA = i;
+ closestB = j;
+ }
+ }
+ }
+ assert (closestA != closestB);
- // 3.1 Try to forget old kernels
- for ( int i = 0; i < kernels.length; i++ ) {
- if ( kernels[i].getRelevanceStamp() < threshold ) {
- kernels[i] = new ClustreamKernel( instance, dim, timestamp, t, m );
- return;
- }
- }
+ kernels[closestA].add(kernels[closestB]);
+ kernels[closestB] = new ClustreamKernel(instance, dim, timestamp, t, m);
+ }
- // 3.2 Merge closest two kernels
- int closestA = 0;
- int closestB = 0;
- minDistance = Double.MAX_VALUE;
- for ( int i = 0; i < kernels.length; i++ ) {
- double[] centerA = kernels[i].getCenter();
- for ( int j = i + 1; j < kernels.length; j++ ) {
- double dist = distance( centerA, kernels[j].getCenter() );
- if ( dist < minDistance ) {
- minDistance = dist;
- closestA = i;
- closestB = j;
- }
- }
- }
- assert (closestA != closestB);
+ @Override
+ public Clustering getMicroClusteringResult() {
+ if (!initialized) {
+ return new Clustering(new Cluster[0]);
+ }
- kernels[closestA].add( kernels[closestB] );
- kernels[closestB] = new ClustreamKernel( instance, dim, timestamp, t, m );
- }
+ ClustreamKernel[] res = new ClustreamKernel[kernels.length];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = new ClustreamKernel(kernels[i], t, m);
+ }
- @Override
- public Clustering getMicroClusteringResult() {
- if ( !initialized ) {
- return new Clustering( new Cluster[0] );
- }
+ return new Clustering(res);
+ }
- ClustreamKernel[] res = new ClustreamKernel[kernels.length];
- for ( int i = 0; i < res.length; i++ ) {
- res[i] = new ClustreamKernel( kernels[i], t, m );
- }
+ @Override
+ public boolean implementsMicroClusterer() {
+ return true;
+ }
- return new Clustering( res );
- }
+ @Override
+ public Clustering getClusteringResult() {
+ return null;
+ }
- @Override
- public boolean implementsMicroClusterer() {
- return true;
- }
+ public String getName() {
+ return "Clustream " + timeWindow;
+ }
- @Override
- public Clustering getClusteringResult() {
- return null;
- }
+ private static double distance(double[] pointA, double[] pointB) {
+ double distance = 0.0;
+ for (int i = 0; i < pointA.length; i++) {
+ double d = pointA[i] - pointB[i];
+ distance += d * d;
+ }
+ return Math.sqrt(distance);
+ }
- public String getName() {
- return "Clustream " + timeWindow;
- }
+ // wrapper... we need to rewrite kmeans to points, not clusters, doesn't make
+ // sense anymore
+ // public static Clustering kMeans( int k, ArrayList<Instance> points, int dim
+ // ) {
+ // ArrayList<ClustreamKernel> cl = new ArrayList<ClustreamKernel>();
+ // for(Instance inst : points){
+ // cl.add(new ClustreamKernel(inst, dim , 0, 0, 0));
+ // }
+ // Clustering clustering = kMeans(k, cl);
+ // return clustering;
+ // }
- private static double distance(double[] pointA, double [] pointB){
- double distance = 0.0;
- for (int i = 0; i < pointA.length; i++) {
- double d = pointA[i] - pointB[i];
- distance += d * d;
- }
- return Math.sqrt(distance);
- }
+ public static Clustering kMeans(int k, List<? extends Cluster> data) {
+ Random random = new Random(0);
+ Cluster[] centers = new Cluster[k];
+ for (int i = 0; i < centers.length; i++) {
+ int rid = random.nextInt(k);
+ centers[i] = new SphereCluster(data.get(rid).getCenter(), 0);
+ }
+ return kMeans(k, centers, data);
+ }
- //wrapper... we need to rewrite kmeans to points, not clusters, doesnt make sense anymore
- // public static Clustering kMeans( int k, ArrayList<Instance> points, int dim ) {
- // ArrayList<ClustreamKernel> cl = new ArrayList<ClustreamKernel>();
- // for(Instance inst : points){
- // cl.add(new ClustreamKernel(inst, dim , 0, 0, 0));
- // }
- // Clustering clustering = kMeans(k, cl);
- // return clustering;
- // }
+ public static Clustering kMeans(int k, Cluster[] centers, List<? extends Cluster> data) {
+ assert (centers.length == k);
+ assert (k > 0);
- public static Clustering kMeans( int k, List<? extends Cluster> data ) {
- Random random = new Random(0);
- Cluster[] centers = new Cluster[k];
- for (int i = 0; i < centers.length; i++) {
- int rid = random.nextInt(k);
- centers[i] = new SphereCluster(data.get(rid).getCenter(),0);
- }
- return kMeans(k, centers, data);
- }
+ int dimensions = centers[0].getCenter().length;
+ ArrayList<ArrayList<Cluster>> clustering = new ArrayList<>();
+ for (int i = 0; i < k; i++) {
+ clustering.add(new ArrayList<Cluster>());
+ }
+ int repetitions = 100;
+ while (repetitions-- >= 0) {
+ // Assign points to clusters
+ for (Cluster point : data) {
+ double minDistance = distance(point.getCenter(), centers[0].getCenter());
+ int closestCluster = 0;
+ for (int i = 1; i < k; i++) {
+ double distance = distance(point.getCenter(), centers[i].getCenter());
+ if (distance < minDistance) {
+ closestCluster = i;
+ minDistance = distance;
+ }
+ }
+ clustering.get(closestCluster).add(point);
+ }
+ // Calculate new centers and clear clustering lists
+ SphereCluster[] newCenters = new SphereCluster[centers.length];
+ for (int i = 0; i < k; i++) {
+ newCenters[i] = calculateCenter(clustering.get(i), dimensions);
+ clustering.get(i).clear();
+ }
+ centers = newCenters;
+ }
- public static Clustering kMeans( int k, Cluster[] centers, List<? extends Cluster> data ) {
- assert (centers.length == k);
- assert (k > 0);
+ return new Clustering(centers);
+ }
- int dimensions = centers[0].getCenter().length;
+ private static SphereCluster calculateCenter(ArrayList<Cluster> cluster, int dimensions) {
+ double[] res = new double[dimensions];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = 0.0;
+ }
- ArrayList<ArrayList<Cluster>> clustering = new ArrayList<>();
- for ( int i = 0; i < k; i++ ) {
- clustering.add( new ArrayList<Cluster>() );
- }
+ if (cluster.size() == 0) {
+ return new SphereCluster(res, 0.0);
+ }
- int repetitions = 100;
- while ( repetitions-- >= 0 ) {
- // Assign points to clusters
- for ( Cluster point : data ) {
- double minDistance = distance( point.getCenter(), centers[0].getCenter() );
- int closestCluster = 0;
- for ( int i = 1; i < k; i++ ) {
- double distance = distance( point.getCenter(), centers[i].getCenter() );
- if ( distance < minDistance ) {
- closestCluster = i;
- minDistance = distance;
- }
- }
+ for (Cluster point : cluster) {
+ double[] center = point.getCenter();
+ for (int i = 0; i < res.length; i++) {
+ res[i] += center[i];
+ }
+ }
- clustering.get( closestCluster ).add( point );
- }
+ // Normalize
+ for (int i = 0; i < res.length; i++) {
+ res[i] /= cluster.size();
+ }
- // Calculate new centers and clear clustering lists
- SphereCluster[] newCenters = new SphereCluster[centers.length];
- for ( int i = 0; i < k; i++ ) {
- newCenters[i] = calculateCenter( clustering.get( i ), dimensions );
- clustering.get( i ).clear();
- }
- centers = newCenters;
- }
+ // Calculate radius
+ double radius = 0.0;
+ for (Cluster point : cluster) {
+ double dist = distance(res, point.getCenter());
+ if (dist > radius) {
+ radius = dist;
+ }
+ }
+ SphereCluster sc = new SphereCluster(res, radius);
+ sc.setWeight(cluster.size());
+ return sc;
+ }
- return new Clustering( centers );
- }
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
- private static SphereCluster calculateCenter( ArrayList<Cluster> cluster, int dimensions ) {
- double[] res = new double[dimensions];
- for ( int i = 0; i < res.length; i++ ) {
- res[i] = 0.0;
- }
+ @Override
+ public void getModelDescription(StringBuilder out, int indent) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
- if ( cluster.size() == 0 ) {
- return new SphereCluster( res, 0.0 );
- }
+ public boolean isRandomizable() {
+ return false;
+ }
- for ( Cluster point : cluster ) {
- double [] center = point.getCenter();
- for (int i = 0; i < res.length; i++) {
- res[i] += center[i];
- }
- }
-
- // Normalize
- for ( int i = 0; i < res.length; i++ ) {
- res[i] /= cluster.size();
- }
-
- // Calculate radius
- double radius = 0.0;
- for ( Cluster point : cluster ) {
- double dist = distance( res, point.getCenter() );
- if ( dist > radius ) {
- radius = dist;
- }
- }
- SphereCluster sc = new SphereCluster( res, radius );
- sc.setWeight(cluster.size());
- return sc;
- }
-
- @Override
- protected Measurement[] getModelMeasurementsImpl() {
- throw new UnsupportedOperationException("Not supported yet.");
- }
-
- @Override
- public void getModelDescription(StringBuilder out, int indent) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
-
- public boolean isRandomizable() {
- return false;
- }
-
- public double[] getVotesForInstance(Instance inst) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
-
+ public double[] getVotesForInstance(Instance inst) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
}
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/ClustreamKernel.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/ClustreamKernel.java
index 8c8f6a0..d5123d2 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/ClustreamKernel.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/ClustreamKernel.java
@@ -25,245 +25,249 @@
import com.yahoo.labs.samoa.moa.cluster.CFCluster;
public class ClustreamKernel extends CFCluster {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- private final static double EPSILON = 0.00005;
- public static final double MIN_VARIANCE = 1e-50;
+ private final static double EPSILON = 0.00005;
+ public static final double MIN_VARIANCE = 1e-50;
- protected double LST;
- protected double SST;
+ protected double LST;
+ protected double SST;
- int m;
- double t;
+ int m;
+ double t;
+ public ClustreamKernel(Instance instance, int dimensions, long timestamp, double t, int m) {
+ super(instance, dimensions);
+ this.t = t;
+ this.m = m;
+ this.LST = timestamp;
+ this.SST = timestamp * timestamp;
+ }
- public ClustreamKernel( Instance instance, int dimensions, long timestamp , double t, int m) {
- super(instance, dimensions);
- this.t = t;
- this.m = m;
- this.LST = timestamp;
- this.SST = timestamp*timestamp;
+ public ClustreamKernel(ClustreamKernel cluster, double t, int m) {
+ super(cluster);
+ this.t = t;
+ this.m = m;
+ this.LST = cluster.LST;
+ this.SST = cluster.SST;
+ }
+
+ public void insert(Instance instance, long timestamp) {
+ N++;
+ LST += timestamp;
+ SST += timestamp * timestamp;
+
+ for (int i = 0; i < instance.numValues(); i++) {
+ LS[i] += instance.value(i);
+ SS[i] += instance.value(i) * instance.value(i);
}
+ }
- public ClustreamKernel( ClustreamKernel cluster, double t, int m ) {
- super(cluster);
- this.t = t;
- this.m = m;
- this.LST = cluster.LST;
- this.SST = cluster.SST;
+ @Override
+ public void add(CFCluster other2) {
+ ClustreamKernel other = (ClustreamKernel) other2;
+ assert (other.LS.length == this.LS.length);
+ this.N += other.N;
+ this.LST += other.LST;
+ this.SST += other.SST;
+
+ for (int i = 0; i < LS.length; i++) {
+ this.LS[i] += other.LS[i];
+ this.SS[i] += other.SS[i];
}
+ }
- public void insert( Instance instance, long timestamp ) {
- N++;
- LST += timestamp;
- SST += timestamp*timestamp;
-
- for ( int i = 0; i < instance.numValues(); i++ ) {
- LS[i] += instance.value(i);
- SS[i] += instance.value(i)*instance.value(i);
- }
+ public double getRelevanceStamp() {
+ if (N < 2 * m)
+ return getMuTime();
+
+ return getMuTime() + getSigmaTime() * getQuantile(((double) m) / (2 * N));
+ }
+
+ private double getMuTime() {
+ return LST / N;
+ }
+
+ private double getSigmaTime() {
+ return Math.sqrt(SST / N - (LST / N) * (LST / N));
+ }
+
+ private double getQuantile(double z) {
+ assert (z >= 0 && z <= 1);
+ return Math.sqrt(2) * inverseError(2 * z - 1);
+ }
+
+ @Override
+ public double getRadius() {
+ // trivial cluster
+ if (N == 1)
+ return 0;
+ if (t == 1)
+ t = 1;
+
+ return getDeviation() * radiusFactor;
+ }
+
+ @Override
+ public CFCluster getCF() {
+ return this;
+ }
+
+ private double getDeviation() {
+ double[] variance = getVarianceVector();
+ double sumOfDeviation = 0.0;
+ for (double aVariance : variance) {
+ double d = Math.sqrt(aVariance);
+ sumOfDeviation += d;
}
+ return sumOfDeviation / variance.length;
+ }
- @Override
- public void add( CFCluster other2 ) {
- ClustreamKernel other = (ClustreamKernel) other2;
- assert( other.LS.length == this.LS.length );
- this.N += other.N;
- this.LST += other.LST;
- this.SST += other.SST;
-
- for ( int i = 0; i < LS.length; i++ ) {
- this.LS[i] += other.LS[i];
- this.SS[i] += other.SS[i];
- }
+ /**
+ * @return this kernel's center
+ */
+ @Override
+ public double[] getCenter() {
+ assert (!this.isEmpty());
+ double res[] = new double[this.LS.length];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = this.LS[i] / N;
}
+ return res;
+ }
- public double getRelevanceStamp() {
- if ( N < 2*m )
- return getMuTime();
-
- return getMuTime() + getSigmaTime() * getQuantile( ((double)m)/(2*N) );
+ /**
+ * See interface <code>Cluster</code>
+ *
+ * @param instance
+ * @return double value
+ */
+ @Override
+ public double getInclusionProbability(Instance instance) {
+ // trivial cluster
+ if (N == 1) {
+ double distance = 0.0;
+ for (int i = 0; i < LS.length; i++) {
+ double d = LS[i] - instance.value(i);
+ distance += d * d;
+ }
+ distance = Math.sqrt(distance);
+ if (distance < EPSILON)
+ return 1.0;
+ return 0.0;
}
-
- private double getMuTime() {
- return LST / N;
+ else {
+ double dist = calcNormalizedDistance(instance.toDoubleArray());
+ if (dist <= getRadius()) {
+ return 1;
+ }
+ else {
+ return 0;
+ }
+ // double res = AuxiliaryFunctions.distanceProbabilty(dist, LS.length);
+ // return res;
}
+ }
- private double getSigmaTime() {
- return Math.sqrt(SST/N - (LST/N)*(LST/N));
- }
+ private double[] getVarianceVector() {
+ double[] res = new double[this.LS.length];
+ for (int i = 0; i < this.LS.length; i++) {
+ double ls = this.LS[i];
+ double ss = this.SS[i];
- private double getQuantile( double z ) {
- assert( z >= 0 && z <= 1 );
- return Math.sqrt( 2 ) * inverseError( 2*z - 1 );
- }
+ double lsDivN = ls / this.getWeight();
+ double lsDivNSquared = lsDivN * lsDivN;
+ double ssDivN = ss / this.getWeight();
+ res[i] = ssDivN - lsDivNSquared;
- @Override
- public double getRadius() {
- //trivial cluster
- if(N == 1) return 0;
- if(t==1)
- t=1;
-
- return getDeviation()*radiusFactor;
- }
-
- @Override
- public CFCluster getCF(){
- return this;
- }
-
-
- private double getDeviation(){
- double[] variance = getVarianceVector();
- double sumOfDeviation = 0.0;
- for (double aVariance : variance) {
- double d = Math.sqrt(aVariance);
- sumOfDeviation += d;
+ // Due to numerical errors, small negative values can occur.
+ // We correct this by setting them to almost zero.
+ if (res[i] <= 0.0) {
+ if (res[i] > -EPSILON) {
+ res[i] = MIN_VARIANCE;
}
- return sumOfDeviation / variance.length;
+ }
+ }
+ return res;
+ }
+
+ /**
+ * Check if this cluster is empty or not.
+ *
+ * @return <code>true</code> if the cluster has no data points,
+ * <code>false</code> otherwise.
+ */
+ public boolean isEmpty() {
+ return this.N == 0;
+ }
+
+ /**
+ * Calculate the normalized Euclidean distance (Mahalanobis distance for
+ * distribution w/o covariances) to a point.
+ *
+ * @param point
+ * The point to which the distance is calculated.
+ * @return The normalized distance to the cluster center.
+ *
+ * TODO: check whether WEIGHTING is correctly applied to variances
+ */
+ // ???????
+ private double calcNormalizedDistance(double[] point) {
+ double[] center = getCenter();
+ double res = 0.0;
+
+ for (int i = 0; i < center.length; i++) {
+ double diff = center[i] - point[i];
+ res += (diff * diff);// variance[i];
+ }
+ return Math.sqrt(res);
+ }
+
+ /**
+ * Approximates the inverse error function. Clustream needs this.
+ *
+ * @param x
+ */
+ public static double inverseError(double x) {
+ double z = Math.sqrt(Math.PI) * x;
+ double res = (z) / 2;
+
+ double z2 = z * z;
+ double zProd = z * z2; // z^3
+ res += (1.0 / 24) * zProd;
+
+ zProd *= z2; // z^5
+ res += (7.0 / 960) * zProd;
+
+ zProd *= z2; // z^7
+ res += (127 * zProd) / 80640;
+
+ zProd *= z2; // z^9
+ res += (4369 * zProd) / 11612160;
+
+ zProd *= z2; // z^11
+ res += (34807 * zProd) / 364953600;
+
+ zProd *= z2; // z^13
+ res += (20036983 * zProd) / 797058662400d;
+
+ return res;
+ }
+
+ @Override
+ protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue) {
+ super.getClusterSpecificInfo(infoTitle, infoValue);
+ infoTitle.add("Deviation");
+
+ double[] variance = getVarianceVector();
+ double sumOfDeviation = 0.0;
+ for (double aVariance : variance) {
+ double d = Math.sqrt(aVariance);
+ sumOfDeviation += d;
}
- /**
- * @return this kernels' center
- */
- @Override
- public double[] getCenter() {
- assert (!this.isEmpty());
- double res[] = new double[this.LS.length];
- for (int i = 0; i < res.length; i++) {
- res[i] = this.LS[i] / N;
- }
- return res;
- }
+ sumOfDeviation /= variance.length;
- /**
- * See interface <code>Cluster</code>
- * @param instance
- * @return double value
- */
- @Override
- public double getInclusionProbability(Instance instance) {
- //trivial cluster
- if(N == 1){
- double distance = 0.0;
- for (int i = 0; i < LS.length; i++) {
- double d = LS[i] - instance.value(i);
- distance += d * d;
- }
- distance = Math.sqrt(distance);
- if( distance < EPSILON )
- return 1.0;
- return 0.0;
- }
- else{
- double dist = calcNormalizedDistance(instance.toDoubleArray());
- if(dist <= getRadius()){
- return 1;
- }
- else{
- return 0;
- }
-// double res = AuxiliaryFunctions.distanceProbabilty(dist, LS.length);
-// return res;
- }
- }
-
- private double[] getVarianceVector() {
- double[] res = new double[this.LS.length];
- for (int i = 0; i < this.LS.length; i++) {
- double ls = this.LS[i];
- double ss = this.SS[i];
-
- double lsDivN = ls / this.getWeight();
- double lsDivNSquared = lsDivN * lsDivN;
- double ssDivN = ss / this.getWeight();
- res[i] = ssDivN - lsDivNSquared;
-
- // Due to numerical errors, small negative values can occur.
- // We correct this by settings them to almost zero.
- if (res[i] <= 0.0) {
- if (res[i] > -EPSILON) {
- res[i] = MIN_VARIANCE;
- }
- }
- }
- return res;
- }
-
- /**
- * Check if this cluster is empty or not.
- * @return <code>true</code> if the cluster has no data points,
- * <code>false</code> otherwise.
- */
- public boolean isEmpty() {
- return this.N == 0;
- }
-
- /**
- * Calculate the normalized euclidean distance (Mahalanobis distance for
- * distribution w/o covariances) to a point.
- * @param point The point to which the distance is calculated.
- * @return The normalized distance to the cluster center.
- *
- * TODO: check whether WEIGHTING is correctly applied to variances
- */
- //???????
- private double calcNormalizedDistance(double[] point) {
- double[] center = getCenter();
- double res = 0.0;
-
- for (int i = 0; i < center.length; i++) {
- double diff = center[i] - point[i];
- res += (diff * diff);// variance[i];
- }
- return Math.sqrt(res);
- }
-
- /**
- * Approximates the inverse error function. Clustream needs this.
- * @param x
- */
- public static double inverseError(double x) {
- double z = Math.sqrt(Math.PI) * x;
- double res = (z) / 2;
-
- double z2 = z * z;
- double zProd = z * z2; // z^3
- res += (1.0 / 24) * zProd;
-
- zProd *= z2; // z^5
- res += (7.0 / 960) * zProd;
-
- zProd *= z2; // z^7
- res += (127 * zProd) / 80640;
-
- zProd *= z2; // z^9
- res += (4369 * zProd) / 11612160;
-
- zProd *= z2; // z^11
- res += (34807 * zProd) / 364953600;
-
- zProd *= z2; // z^13
- res += (20036983 * zProd) / 797058662400d;
-
- return res;
- }
-
- @Override
- protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue) {
- super.getClusterSpecificInfo(infoTitle, infoValue);
- infoTitle.add("Deviation");
-
- double[] variance = getVarianceVector();
- double sumOfDeviation = 0.0;
- for (double aVariance : variance) {
- double d = Math.sqrt(aVariance);
- sumOfDeviation += d;
- }
-
- sumOfDeviation/= variance.length;
-
- infoValue.add(Double.toString(sumOfDeviation));
- }
+ infoValue.add(Double.toString(sumOfDeviation));
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/WithKmeans.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/WithKmeans.java
index 9e40fc1..08cf936 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/WithKmeans.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/clusterers/clustream/WithKmeans.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.clusterers.clustream;
/*
@@ -38,431 +37,434 @@
import com.yahoo.labs.samoa.instances.Instance;
public class WithKmeans extends AbstractClusterer {
-
- private static final long serialVersionUID = 1L;
- public IntOption timeWindowOption = new IntOption("horizon",
- 'h', "Rang of the window.", 1000);
+ private static final long serialVersionUID = 1L;
- public IntOption maxNumKernelsOption = new IntOption(
- "maxNumKernels", 'm',
- "Maximum number of micro kernels to use.", 100);
+ public IntOption timeWindowOption = new IntOption("horizon",
+ 'h', "Rang of the window.", 1000);
- public IntOption kernelRadiFactorOption = new IntOption(
- "kernelRadiFactor", 't',
- "Multiplier for the kernel radius", 2);
-
- public IntOption kOption = new IntOption(
- "k", 'k',
- "k of macro k-means (number of clusters)", 5);
+ public IntOption maxNumKernelsOption = new IntOption(
+ "maxNumKernels", 'm',
+ "Maximum number of micro kernels to use.", 100);
- private int timeWindow;
- private long timestamp = -1;
- private ClustreamKernel[] kernels;
- private boolean initialized;
- private List<ClustreamKernel> buffer; // Buffer for initialization with kNN
- private int bufferSize;
- private double t;
- private int m;
-
- public WithKmeans() {
-
- }
+ public IntOption kernelRadiFactorOption = new IntOption(
+ "kernelRadiFactor", 't',
+ "Multiplier for the kernel radius", 2);
- @Override
- public void resetLearningImpl() {
- this.kernels = new ClustreamKernel[maxNumKernelsOption.getValue()];
- this.timeWindow = timeWindowOption.getValue();
- this.initialized = false;
- this.buffer = new LinkedList<ClustreamKernel>();
- this.bufferSize = maxNumKernelsOption.getValue();
- t = kernelRadiFactorOption.getValue();
- m = maxNumKernelsOption.getValue();
- }
+ public IntOption kOption = new IntOption(
+ "k", 'k',
+ "k of macro k-means (number of clusters)", 5);
- @Override
- public void trainOnInstanceImpl(Instance instance) {
- int dim = instance.numValues();
- timestamp++;
- // 0. Initialize
- if (!initialized) {
- if (buffer.size() < bufferSize) {
- buffer.add(new ClustreamKernel(instance, dim, timestamp, t, m));
- return;
- } else {
- for (int i = 0; i < buffer.size(); i++) {
- kernels[i] = new ClustreamKernel(new DenseInstance(1.0, buffer.get(i).getCenter()), dim, timestamp, t, m);
- }
-
- buffer.clear();
- initialized = true;
- return;
- }
- }
+ private int timeWindow;
+ private long timestamp = -1;
+ private ClustreamKernel[] kernels;
+ private boolean initialized;
+ private List<ClustreamKernel> buffer; // Buffer for initialization with kNN
+ private int bufferSize;
+ private double t;
+ private int m;
+ public WithKmeans() {
- // 1. Determine closest kernel
- ClustreamKernel closestKernel = null;
- double minDistance = Double.MAX_VALUE;
- for ( int i = 0; i < kernels.length; i++ ) {
- //System.out.println(i+" "+kernels[i].getWeight()+" "+kernels[i].getDeviation());
- double distance = distance(instance.toDoubleArray(), kernels[i].getCenter());
- if (distance < minDistance) {
- closestKernel = kernels[i];
- minDistance = distance;
- }
- }
+ }
- // 2. Check whether instance fits into closestKernel
- double radius = 0.0;
- if ( closestKernel.getWeight() == 1 ) {
- // Special case: estimate radius by determining the distance to the
- // next closest cluster
- radius = Double.MAX_VALUE;
- double[] center = closestKernel.getCenter();
- for ( int i = 0; i < kernels.length; i++ ) {
- if ( kernels[i] == closestKernel ) {
- continue;
- }
+ @Override
+ public void resetLearningImpl() {
+ this.kernels = new ClustreamKernel[maxNumKernelsOption.getValue()];
+ this.timeWindow = timeWindowOption.getValue();
+ this.initialized = false;
+ this.buffer = new LinkedList<ClustreamKernel>();
+ this.bufferSize = maxNumKernelsOption.getValue();
+ t = kernelRadiFactorOption.getValue();
+ m = maxNumKernelsOption.getValue();
+ }
- double distance = distance(kernels[i].getCenter(), center );
- radius = Math.min( distance, radius );
- }
- } else {
- radius = closestKernel.getRadius();
- }
-
- if ( minDistance < radius ) {
- // Date fits, put into kernel and be happy
- closestKernel.insert( instance, timestamp );
- return;
- }
-
- // 3. Date does not fit, we need to free
- // some space to insert a new kernel
- long threshold = timestamp - timeWindow; // Kernels before this can be forgotten
-
- // 3.1 Try to forget old kernels
- for ( int i = 0; i < kernels.length; i++ ) {
- if ( kernels[i].getRelevanceStamp() < threshold ) {
- kernels[i] = new ClustreamKernel( instance, dim, timestamp, t, m );
- return;
- }
- }
-
- // 3.2 Merge closest two kernels
- int closestA = 0;
- int closestB = 0;
- minDistance = Double.MAX_VALUE;
- for ( int i = 0; i < kernels.length; i++ ) {
- double[] centerA = kernels[i].getCenter();
- for ( int j = i + 1; j < kernels.length; j++ ) {
- double dist = distance( centerA, kernels[j].getCenter() );
- if ( dist < minDistance ) {
- minDistance = dist;
- closestA = i;
- closestB = j;
- }
- }
- }
- assert (closestA != closestB);
-
- kernels[closestA].add( kernels[closestB] );
- kernels[closestB] = new ClustreamKernel( instance, dim, timestamp, t, m );
- }
-
- @Override
- public Clustering getMicroClusteringResult() {
- if (!initialized) {
- return new Clustering(new Cluster[0]);
- }
-
- ClustreamKernel[] result = new ClustreamKernel[kernels.length];
- for (int i = 0; i < result.length; i++) {
- result[i] = new ClustreamKernel(kernels[i], t, m);
- }
-
- return new Clustering(result);
- }
-
- @Override
- public Clustering getClusteringResult() {
- return kMeans_rand(kOption.getValue(), getMicroClusteringResult());
- }
-
- public Clustering getClusteringResult(Clustering gtClustering) {
- return kMeans_gta(kOption.getValue(), getMicroClusteringResult(), gtClustering);
- }
-
- public String getName() {
- return "CluStreamWithKMeans " + timeWindow;
- }
-
- /**
- * Distance between two vectors.
- *
- * @param pointA
- * @param pointB
- * @return dist
- */
- private static double distance(double[] pointA, double [] pointB) {
- double distance = 0.0;
- for (int i = 0; i < pointA.length; i++) {
- double d = pointA[i] - pointB[i];
- distance += d * d;
- }
- return Math.sqrt(distance);
- }
-
- /**
- * k-means of (micro)clusters, with ground-truth-aided initialization.
- * (to produce best results)
- *
- * @param k
- * @param data
- * @return (macro)clustering - CFClusters
- */
- public static Clustering kMeans_gta(int k, Clustering clustering, Clustering gtClustering) {
-
- ArrayList<CFCluster> microclusters = new ArrayList<CFCluster>();
- for (int i = 0; i < clustering.size(); i++) {
- if (clustering.get(i) instanceof CFCluster) {
- microclusters.add((CFCluster)clustering.get(i));
- } else {
- System.out.println("Unsupported Cluster Type:" + clustering.get(i).getClass() + ". Cluster needs to extend moa.cluster.CFCluster");
- }
+ @Override
+ public void trainOnInstanceImpl(Instance instance) {
+ int dim = instance.numValues();
+ timestamp++;
+ // 0. Initialize
+ if (!initialized) {
+ if (buffer.size() < bufferSize) {
+ buffer.add(new ClustreamKernel(instance, dim, timestamp, t, m));
+ return;
+ } else {
+ for (int i = 0; i < buffer.size(); i++) {
+ kernels[i] = new ClustreamKernel(new DenseInstance(1.0, buffer.get(i).getCenter()), dim, timestamp, t, m);
}
-
- int n = microclusters.size();
- assert (k <= n);
-
- /* k-means */
- Random random = new Random(0);
- Cluster[] centers = new Cluster[k];
- int K = gtClustering.size();
-
- for (int i = 0; i < k; i++) {
- if (i < K) { // GT-aided
- centers[i] = new SphereCluster(gtClustering.get(i).getCenter(), 0);
- } else { // Randomized
- int rid = random.nextInt(n);
- centers[i] = new SphereCluster(microclusters.get(rid).getCenter(), 0);
- }
- }
-
- return cleanUpKMeans(kMeans(k, centers, microclusters), microclusters);
- }
-
- /**
- * k-means of (micro)clusters, with randomized initialization.
- *
- * @param k
- * @param data
- * @return (macro)clustering - CFClusters
- */
- public static Clustering kMeans_rand(int k, Clustering clustering) {
-
- ArrayList<CFCluster> microclusters = new ArrayList<CFCluster>();
- for (int i = 0; i < clustering.size(); i++) {
- if (clustering.get(i) instanceof CFCluster) {
- microclusters.add((CFCluster)clustering.get(i));
- } else {
- System.out.println("Unsupported Cluster Type:" + clustering.get(i).getClass() + ". Cluster needs to extend moa.cluster.CFCluster");
- }
+
+ buffer.clear();
+ initialized = true;
+ return;
+ }
+ }
+
+ // 1. Determine closest kernel
+ ClustreamKernel closestKernel = null;
+ double minDistance = Double.MAX_VALUE;
+ for (int i = 0; i < kernels.length; i++) {
+ // System.out.println(i+" "+kernels[i].getWeight()+" "+kernels[i].getDeviation());
+ double distance = distance(instance.toDoubleArray(), kernels[i].getCenter());
+ if (distance < minDistance) {
+ closestKernel = kernels[i];
+ minDistance = distance;
+ }
+ }
+
+ // 2. Check whether instance fits into closestKernel
+ double radius = 0.0;
+ if (closestKernel.getWeight() == 1) {
+ // Special case: estimate radius by determining the distance to the
+ // next closest cluster
+ radius = Double.MAX_VALUE;
+ double[] center = closestKernel.getCenter();
+ for (int i = 0; i < kernels.length; i++) {
+ if (kernels[i] == closestKernel) {
+ continue;
}
-
- int n = microclusters.size();
- assert (k <= n);
-
- /* k-means */
- Random random = new Random(0);
- Cluster[] centers = new Cluster[k];
-
- for (int i = 0; i < k; i++) {
- int rid = random.nextInt(n);
- centers[i] = new SphereCluster(microclusters.get(rid).getCenter(), 0);
- }
-
- return cleanUpKMeans(kMeans(k, centers, microclusters), microclusters);
- }
-
- /**
- * (The Actual Algorithm) k-means of (micro)clusters, with specified initialization points.
- *
- * @param k
- * @param centers - initial centers
- * @param data
- * @return (macro)clustering - SphereClusters
- */
- protected static Clustering kMeans(int k, Cluster[] centers, List<? extends Cluster> data) {
- assert (centers.length == k);
- assert (k > 0);
- int dimensions = centers[0].getCenter().length;
+ double distance = distance(kernels[i].getCenter(), center);
+ radius = Math.min(distance, radius);
+ }
+ } else {
+ radius = closestKernel.getRadius();
+ }
- ArrayList<ArrayList<Cluster>> clustering = new ArrayList<ArrayList<Cluster>>();
- for (int i = 0; i < k; i++) {
- clustering.add(new ArrayList<Cluster>());
- }
+ if (minDistance < radius) {
+ // Date fits, put into kernel and be happy
+ closestKernel.insert(instance, timestamp);
+ return;
+ }
- while (true) {
- // Assign points to clusters
- for (Cluster point : data) {
- double minDistance = distance(point.getCenter(), centers[0].getCenter());
- int closestCluster = 0;
- for (int i = 1; i < k; i++) {
- double distance = distance(point.getCenter(), centers[i].getCenter());
- if (distance < minDistance) {
- closestCluster = i;
- minDistance = distance;
- }
- }
+ // 3. Date does not fit, we need to free
+ // some space to insert a new kernel
+ long threshold = timestamp - timeWindow; // Kernels before this can be
+ // forgotten
- clustering.get(closestCluster).add(point);
- }
+ // 3.1 Try to forget old kernels
+ for (int i = 0; i < kernels.length; i++) {
+ if (kernels[i].getRelevanceStamp() < threshold) {
+ kernels[i] = new ClustreamKernel(instance, dim, timestamp, t, m);
+ return;
+ }
+ }
- // Calculate new centers and clear clustering lists
- SphereCluster[] newCenters = new SphereCluster[centers.length];
- for (int i = 0; i < k; i++) {
- newCenters[i] = calculateCenter(clustering.get(i), dimensions);
- clustering.get(i).clear();
- }
-
- // Convergence check
- boolean converged = true;
- for (int i = 0; i < k; i++) {
- if (!Arrays.equals(centers[i].getCenter(), newCenters[i].getCenter())) {
- converged = false;
- break;
- }
- }
-
- if (converged) {
- break;
- } else {
- centers = newCenters;
- }
- }
+ // 3.2 Merge closest two kernels
+ int closestA = 0;
+ int closestB = 0;
+ minDistance = Double.MAX_VALUE;
+ for (int i = 0; i < kernels.length; i++) {
+ double[] centerA = kernels[i].getCenter();
+ for (int j = i + 1; j < kernels.length; j++) {
+ double dist = distance(centerA, kernels[j].getCenter());
+ if (dist < minDistance) {
+ minDistance = dist;
+ closestA = i;
+ closestB = j;
+ }
+ }
+ }
+ assert (closestA != closestB);
- return new Clustering(centers);
- }
-
- /**
- * Rearrange the k-means result into a set of CFClusters, cleaning up the redundancies.
- *
- * @param kMeansResult
- * @param microclusters
- * @return
- */
- protected static Clustering cleanUpKMeans(Clustering kMeansResult, ArrayList<CFCluster> microclusters) {
- /* Convert k-means result to CFClusters */
- int k = kMeansResult.size();
- CFCluster[] converted = new CFCluster[k];
+ kernels[closestA].add(kernels[closestB]);
+ kernels[closestB] = new ClustreamKernel(instance, dim, timestamp, t, m);
+ }
- for (CFCluster mc : microclusters) {
- // Find closest kMeans cluster
- double minDistance = Double.MAX_VALUE;
- int closestCluster = 0;
- for (int i = 0; i < k; i++) {
- double distance = distance(kMeansResult.get(i).getCenter(), mc.getCenter());
- if (distance < minDistance) {
- closestCluster = i;
- minDistance = distance;
- }
- }
+ @Override
+ public Clustering getMicroClusteringResult() {
+ if (!initialized) {
+ return new Clustering(new Cluster[0]);
+ }
- // Add to cluster
- if ( converted[closestCluster] == null ) {
- converted[closestCluster] = (CFCluster)mc.copy();
- } else {
- converted[closestCluster].add(mc);
- }
- }
+ ClustreamKernel[] result = new ClustreamKernel[kernels.length];
+ for (int i = 0; i < result.length; i++) {
+ result[i] = new ClustreamKernel(kernels[i], t, m);
+ }
- // Clean up
- int count = 0;
- for (int i = 0; i < converted.length; i++) {
- if (converted[i] != null)
- count++;
- }
+ return new Clustering(result);
+ }
- CFCluster[] cleaned = new CFCluster[count];
- count = 0;
- for (int i = 0; i < converted.length; i++) {
- if (converted[i] != null)
- cleaned[count++] = converted[i];
- }
+ @Override
+ public Clustering getClusteringResult() {
+ return kMeans_rand(kOption.getValue(), getMicroClusteringResult());
+ }
- return new Clustering(cleaned);
- }
+ public Clustering getClusteringResult(Clustering gtClustering) {
+ return kMeans_gta(kOption.getValue(), getMicroClusteringResult(), gtClustering);
+ }
-
+ public String getName() {
+ return "CluStreamWithKMeans " + timeWindow;
+ }
- /**
- * k-means helper: Calculate a wrapping cluster of assigned points[microclusters].
- *
- * @param assigned
- * @param dimensions
- * @return SphereCluster (with center and radius)
- */
- private static SphereCluster calculateCenter(ArrayList<Cluster> assigned, int dimensions) {
- double[] result = new double[dimensions];
- for (int i = 0; i < result.length; i++) {
- result[i] = 0.0;
- }
+ /**
+ * Distance between two vectors.
+ *
+ * @param pointA
+ * @param pointB
+ * @return dist
+ */
+ private static double distance(double[] pointA, double[] pointB) {
+ double distance = 0.0;
+ for (int i = 0; i < pointA.length; i++) {
+ double d = pointA[i] - pointB[i];
+ distance += d * d;
+ }
+ return Math.sqrt(distance);
+ }
- if (assigned.size() == 0) {
- return new SphereCluster(result, 0.0);
- }
+ /**
+ * k-means of (micro)clusters, with ground-truth-aided initialization. (to
+ * produce best results)
+ *
+ * @param k
+ * @param data
+ * @return (macro)clustering - CFClusters
+ */
+ public static Clustering kMeans_gta(int k, Clustering clustering, Clustering gtClustering) {
- for (Cluster point : assigned) {
- double[] center = point.getCenter();
- for (int i = 0; i < result.length; i++) {
- result[i] += center[i];
- }
- }
+ ArrayList<CFCluster> microclusters = new ArrayList<CFCluster>();
+ for (int i = 0; i < clustering.size(); i++) {
+ if (clustering.get(i) instanceof CFCluster) {
+ microclusters.add((CFCluster) clustering.get(i));
+ } else {
+ System.out.println("Unsupported Cluster Type:" + clustering.get(i).getClass()
+ + ". Cluster needs to extend moa.cluster.CFCluster");
+ }
+ }
- // Normalize
- for (int i = 0; i < result.length; i++) {
- result[i] /= assigned.size();
- }
+ int n = microclusters.size();
+ assert (k <= n);
- // Calculate radius: biggest wrapping distance from center
- double radius = 0.0;
- for (Cluster point : assigned) {
- double dist = distance(result, point.getCenter());
- if (dist > radius) {
- radius = dist;
- }
- }
- SphereCluster sc = new SphereCluster(result, radius);
- sc.setWeight(assigned.size());
- return sc;
- }
+ /* k-means */
+ Random random = new Random(0);
+ Cluster[] centers = new Cluster[k];
+ int K = gtClustering.size();
-
- /** Miscellaneous **/
-
- @Override
- public boolean implementsMicroClusterer() {
- return true;
- }
-
- public boolean isRandomizable() {
- return false;
- }
-
- public double[] getVotesForInstance(Instance inst) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
-
- @Override
- protected Measurement[] getModelMeasurementsImpl() {
- throw new UnsupportedOperationException("Not supported yet.");
- }
-
- @Override
- public void getModelDescription(StringBuilder out, int indent) {
- throw new UnsupportedOperationException("Not supported yet.");
- }
+ for (int i = 0; i < k; i++) {
+ if (i < K) { // GT-aided
+ centers[i] = new SphereCluster(gtClustering.get(i).getCenter(), 0);
+ } else { // Randomized
+ int rid = random.nextInt(n);
+ centers[i] = new SphereCluster(microclusters.get(rid).getCenter(), 0);
+ }
+ }
+
+ return cleanUpKMeans(kMeans(k, centers, microclusters), microclusters);
+ }
+
+ /**
+ * k-means of (micro)clusters, with randomized initialization.
+ *
+ * @param k
+ * @param data
+ * @return (macro)clustering - CFClusters
+ */
+ public static Clustering kMeans_rand(int k, Clustering clustering) {
+
+ ArrayList<CFCluster> microclusters = new ArrayList<CFCluster>();
+ for (int i = 0; i < clustering.size(); i++) {
+ if (clustering.get(i) instanceof CFCluster) {
+ microclusters.add((CFCluster) clustering.get(i));
+ } else {
+ System.out.println("Unsupported Cluster Type:" + clustering.get(i).getClass()
+ + ". Cluster needs to extend moa.cluster.CFCluster");
+ }
+ }
+
+ int n = microclusters.size();
+ assert (k <= n);
+
+ /* k-means */
+ Random random = new Random(0);
+ Cluster[] centers = new Cluster[k];
+
+ for (int i = 0; i < k; i++) {
+ int rid = random.nextInt(n);
+ centers[i] = new SphereCluster(microclusters.get(rid).getCenter(), 0);
+ }
+
+ return cleanUpKMeans(kMeans(k, centers, microclusters), microclusters);
+ }
+
+ /**
+ * (The Actual Algorithm) k-means of (micro)clusters, with specified
+ * initialization points.
+ *
+ * @param k
+ * @param centers
+ * - initial centers
+ * @param data
+ * @return (macro)clustering - SphereClusters
+ */
+ protected static Clustering kMeans(int k, Cluster[] centers, List<? extends Cluster> data) {
+ assert (centers.length == k);
+ assert (k > 0);
+
+ int dimensions = centers[0].getCenter().length;
+
+ ArrayList<ArrayList<Cluster>> clustering = new ArrayList<ArrayList<Cluster>>();
+ for (int i = 0; i < k; i++) {
+ clustering.add(new ArrayList<Cluster>());
+ }
+
+ while (true) {
+ // Assign points to clusters
+ for (Cluster point : data) {
+ double minDistance = distance(point.getCenter(), centers[0].getCenter());
+ int closestCluster = 0;
+ for (int i = 1; i < k; i++) {
+ double distance = distance(point.getCenter(), centers[i].getCenter());
+ if (distance < minDistance) {
+ closestCluster = i;
+ minDistance = distance;
+ }
+ }
+
+ clustering.get(closestCluster).add(point);
+ }
+
+ // Calculate new centers and clear clustering lists
+ SphereCluster[] newCenters = new SphereCluster[centers.length];
+ for (int i = 0; i < k; i++) {
+ newCenters[i] = calculateCenter(clustering.get(i), dimensions);
+ clustering.get(i).clear();
+ }
+
+ // Convergence check
+ boolean converged = true;
+ for (int i = 0; i < k; i++) {
+ if (!Arrays.equals(centers[i].getCenter(), newCenters[i].getCenter())) {
+ converged = false;
+ break;
+ }
+ }
+
+ if (converged) {
+ break;
+ } else {
+ centers = newCenters;
+ }
+ }
+
+ return new Clustering(centers);
+ }
+
+ /**
+ * Rearrange the k-means result into a set of CFClusters, cleaning up the
+ * redundancies.
+ *
+ * @param kMeansResult
+ * @param microclusters
+ * @return
+ */
+ protected static Clustering cleanUpKMeans(Clustering kMeansResult, ArrayList<CFCluster> microclusters) {
+ /* Convert k-means result to CFClusters */
+ int k = kMeansResult.size();
+ CFCluster[] converted = new CFCluster[k];
+
+ for (CFCluster mc : microclusters) {
+ // Find closest kMeans cluster
+ double minDistance = Double.MAX_VALUE;
+ int closestCluster = 0;
+ for (int i = 0; i < k; i++) {
+ double distance = distance(kMeansResult.get(i).getCenter(), mc.getCenter());
+ if (distance < minDistance) {
+ closestCluster = i;
+ minDistance = distance;
+ }
+ }
+
+ // Add to cluster
+ if (converted[closestCluster] == null) {
+ converted[closestCluster] = (CFCluster) mc.copy();
+ } else {
+ converted[closestCluster].add(mc);
+ }
+ }
+
+ // Clean up
+ int count = 0;
+ for (int i = 0; i < converted.length; i++) {
+ if (converted[i] != null)
+ count++;
+ }
+
+ CFCluster[] cleaned = new CFCluster[count];
+ count = 0;
+ for (int i = 0; i < converted.length; i++) {
+ if (converted[i] != null)
+ cleaned[count++] = converted[i];
+ }
+
+ return new Clustering(cleaned);
+ }
+
+ /**
+ * k-means helper: Calculate a wrapping cluster of assigned
+ * points[microclusters].
+ *
+ * @param assigned
+ * @param dimensions
+ * @return SphereCluster (with center and radius)
+ */
+ private static SphereCluster calculateCenter(ArrayList<Cluster> assigned, int dimensions) {
+ double[] result = new double[dimensions];
+ for (int i = 0; i < result.length; i++) {
+ result[i] = 0.0;
+ }
+
+ if (assigned.size() == 0) {
+ return new SphereCluster(result, 0.0);
+ }
+
+ for (Cluster point : assigned) {
+ double[] center = point.getCenter();
+ for (int i = 0; i < result.length; i++) {
+ result[i] += center[i];
+ }
+ }
+
+ // Normalize
+ for (int i = 0; i < result.length; i++) {
+ result[i] /= assigned.size();
+ }
+
+ // Calculate radius: biggest wrapping distance from center
+ double radius = 0.0;
+ for (Cluster point : assigned) {
+ double dist = distance(result, point.getCenter());
+ if (dist > radius) {
+ radius = dist;
+ }
+ }
+ SphereCluster sc = new SphereCluster(result, radius);
+ sc.setWeight(assigned.size());
+ return sc;
+ }
+
+ /** Miscellaneous **/
+
+ @Override
+ public boolean implementsMicroClusterer() {
+ return true;
+ }
+
+ public boolean isRandomizable() {
+ return false;
+ }
+
+ public double[] getVotesForInstance(Instance inst) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
+
+ @Override
+ protected Measurement[] getModelMeasurementsImpl() {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
+
+ @Override
+ public void getModelDescription(StringBuilder out, int indent) {
+ throw new UnsupportedOperationException("Not supported yet.");
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoClassDiscovery.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoClassDiscovery.java
index f052832..3880e09 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoClassDiscovery.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoClassDiscovery.java
@@ -36,176 +36,161 @@
/**
* Class for discovering classes via reflection in the java class path.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class AutoClassDiscovery {
- protected static final Map<String, String[]> cachedClassNames = new HashMap<String, String[]>();
+ protected static final Map<String, String[]> cachedClassNames = new HashMap<String, String[]>();
- public static String[] findClassNames(String packageNameToSearch) {
- String[] cached = cachedClassNames.get(packageNameToSearch);
- if (cached == null) {
- HashSet<String> classNames = new HashSet<String>();
- /*StringTokenizer pathTokens = new StringTokenizer(System
- .getProperty("java.class.path"), File.pathSeparator);*/
- String packageDirName = packageNameToSearch.replace('.',
- File.separatorChar);
- String packageJarName = packageNameToSearch.length() > 0 ? (packageNameToSearch.replace('.', '/') + "/")
- : "";
- String part = "";
+ public static String[] findClassNames(String packageNameToSearch) {
+ String[] cached = cachedClassNames.get(packageNameToSearch);
+ if (cached == null) {
+ HashSet<String> classNames = new HashSet<String>();
+ /*
+ * StringTokenizer pathTokens = new StringTokenizer(System
+ * .getProperty("java.class.path"), File.pathSeparator);
+ */
+ String packageDirName = packageNameToSearch.replace('.',
+ File.separatorChar);
+ String packageJarName = packageNameToSearch.length() > 0 ? (packageNameToSearch.replace('.', '/') + "/")
+ : "";
+ String part = "";
+ AutoClassDiscovery adc = new AutoClassDiscovery();
+ URLClassLoader sysLoader = (URLClassLoader) adc.getClass().getClassLoader();
+ URL[] cl_urls = sysLoader.getURLs();
- AutoClassDiscovery adc = new AutoClassDiscovery();
- URLClassLoader sysLoader = (URLClassLoader) adc.getClass().getClassLoader();
- URL[] cl_urls = sysLoader.getURLs();
+ for (int i = 0; i < cl_urls.length; i++) {
+ part = cl_urls[i].toString();
+ if (part.startsWith("file:")) {
+ part = part.replace(" ", "%20");
+ try {
+ File temp = new File(new java.net.URI(part));
+ part = temp.getAbsolutePath();
+ } catch (URISyntaxException e) {
+ e.printStackTrace();
+ }
+ }
- for (int i = 0; i < cl_urls.length; i++) {
- part = cl_urls[i].toString();
- if (part.startsWith("file:")) {
- part = part.replace(" ", "%20");
- try {
- File temp = new File(new java.net.URI(part));
- part = temp.getAbsolutePath();
- } catch (URISyntaxException e) {
- e.printStackTrace();
- }
- }
-
- // find classes
- ArrayList<File> files = new ArrayList<File>();
- File dir = new File(part);
- if (dir.isDirectory()) {
- File root = new File(dir.toString() + File.separatorChar + packageDirName);
- String[] names = findClassesInDirectoryRecursive(root, "");
- classNames.addAll(Arrays.asList(names));
- } else {
- try {
- JarFile jar = new JarFile(part);
- Enumeration<JarEntry> jarEntries = jar.entries();
- while (jarEntries.hasMoreElements()) {
- String jarEntry = jarEntries.nextElement().getName();
- if (jarEntry.startsWith(packageJarName)) {
- String relativeName = jarEntry.substring(packageJarName.length());
- if (relativeName.endsWith(".class")) {
- relativeName = relativeName.replace('/',
- '.');
- classNames.add(relativeName.substring(0,
- relativeName.length()
- - ".class".length()));
- }
- }
- }
- } catch (IOException ignored) {
- // ignore unreadable files
- }
- }
- }
-
- /*while (pathTokens.hasMoreElements()) {
- String pathToSearch = pathTokens.nextElement().toString();
- if (pathToSearch.endsWith(".jar")) {
- try {
- JarFile jar = new JarFile(pathToSearch);
+ // find classes
+ ArrayList<File> files = new ArrayList<File>();
+ File dir = new File(part);
+ if (dir.isDirectory()) {
+ File root = new File(dir.toString() + File.separatorChar + packageDirName);
+ String[] names = findClassesInDirectoryRecursive(root, "");
+ classNames.addAll(Arrays.asList(names));
+ } else {
+ try {
+ JarFile jar = new JarFile(part);
Enumeration<JarEntry> jarEntries = jar.entries();
while (jarEntries.hasMoreElements()) {
- String jarEntry = jarEntries.nextElement()
- .getName();
- if (jarEntry.startsWith(packageJarName)) {
- String relativeName = jarEntry
- .substring(packageJarName.length());
- if (relativeName.endsWith(".class")) {
- relativeName = relativeName.replace('/',
- '.');
- classNames.add(relativeName.substring(0,
- relativeName.length()
- - ".class".length()));
+ String jarEntry = jarEntries.nextElement().getName();
+ if (jarEntry.startsWith(packageJarName)) {
+ String relativeName = jarEntry.substring(packageJarName.length());
+ if (relativeName.endsWith(".class")) {
+ relativeName = relativeName.replace('/',
+ '.');
+ classNames.add(relativeName.substring(0,
+ relativeName.length()
+ - ".class".length()));
+ }
+ }
}
- }
- }
- } catch (IOException ignored) {
+ } catch (IOException ignored) {
// ignore unreadable files
- }
- } else {
- File root = new File(pathToSearch + File.separatorChar
- + packageDirName);
- String[] names = findClassesInDirectoryRecursive(root, "");
- for (String name : names) {
- classNames.add(name);
- }
- }
- } */
- cached = classNames.toArray(new String[classNames.size()]);
- Arrays.sort(cached);
- cachedClassNames.put(packageNameToSearch, cached);
+ }
}
- return cached;
- }
+ }
- protected static String[] findClassesInDirectoryRecursive(File root,
- String packagePath) {
- HashSet<String> classNames = new HashSet<String>();
- if (root.isDirectory()) {
- String[] list = root.list();
- for (String string : list) {
- if (string.endsWith(".class")) {
- classNames.add(packagePath
- + string.substring(0, string.length()
- - ".class".length()));
- } else {
- File testDir = new File(root.getPath() + File.separatorChar
- + string);
- if (testDir.isDirectory()) {
- String[] names = findClassesInDirectoryRecursive(
- testDir, packagePath + string + ".");
- classNames.addAll(Arrays.asList(names));
- }
- }
- }
+ /*
+ * while (pathTokens.hasMoreElements()) { String pathToSearch =
+ * pathTokens.nextElement().toString(); if (pathToSearch.endsWith(".jar"))
+ * { try { JarFile jar = new JarFile(pathToSearch); Enumeration<JarEntry>
+ * jarEntries = jar.entries(); while (jarEntries.hasMoreElements()) {
+ * String jarEntry = jarEntries.nextElement() .getName(); if
+ * (jarEntry.startsWith(packageJarName)) { String relativeName = jarEntry
+ * .substring(packageJarName.length()); if
+ * (relativeName.endsWith(".class")) { relativeName =
+ * relativeName.replace('/', '.');
+ * classNames.add(relativeName.substring(0, relativeName.length() -
+ * ".class".length())); } } } } catch (IOException ignored) { // ignore
+ * unreadable files } } else { File root = new File(pathToSearch +
+ * File.separatorChar + packageDirName); String[] names =
+ * findClassesInDirectoryRecursive(root, ""); for (String name : names) {
+ * classNames.add(name); } } }
+ */
+ cached = classNames.toArray(new String[classNames.size()]);
+ Arrays.sort(cached);
+ cachedClassNames.put(packageNameToSearch, cached);
+ }
+ return cached;
+ }
+
+ protected static String[] findClassesInDirectoryRecursive(File root,
+ String packagePath) {
+ HashSet<String> classNames = new HashSet<String>();
+ if (root.isDirectory()) {
+ String[] list = root.list();
+ for (String string : list) {
+ if (string.endsWith(".class")) {
+ classNames.add(packagePath
+ + string.substring(0, string.length()
+ - ".class".length()));
+ } else {
+ File testDir = new File(root.getPath() + File.separatorChar
+ + string);
+ if (testDir.isDirectory()) {
+ String[] names = findClassesInDirectoryRecursive(
+ testDir, packagePath + string + ".");
+ classNames.addAll(Arrays.asList(names));
+ }
}
- return classNames.toArray(new String[classNames.size()]);
+ }
}
+ return classNames.toArray(new String[classNames.size()]);
+ }
- public static Class[] findClassesOfType(String packageNameToSearch,
- Class<?> typeDesired) {
- ArrayList<Class<?>> classesFound = new ArrayList<Class<?>>();
- String[] classNames = findClassNames(packageNameToSearch);
- for (String className : classNames) {
- String fullName = packageNameToSearch.length() > 0 ? (packageNameToSearch
- + "." + className)
- : className;
- if (isPublicConcreteClassOfType(fullName, typeDesired)) {
- try {
- classesFound.add(Class.forName(fullName));
- } catch (Exception ignored) {
- // ignore classes that we cannot instantiate
- }
- }
- }
- return classesFound.toArray(new Class[classesFound.size()]);
- }
-
- public static boolean isPublicConcreteClassOfType(String className,
- Class<?> typeDesired) {
- Class<?> testClass = null;
+ public static Class[] findClassesOfType(String packageNameToSearch,
+ Class<?> typeDesired) {
+ ArrayList<Class<?>> classesFound = new ArrayList<Class<?>>();
+ String[] classNames = findClassNames(packageNameToSearch);
+ for (String className : classNames) {
+ String fullName = packageNameToSearch.length() > 0 ? (packageNameToSearch
+ + "." + className)
+ : className;
+ if (isPublicConcreteClassOfType(fullName, typeDesired)) {
try {
- testClass = Class.forName(className);
- } catch (Exception e) {
- return false;
- }
- int classModifiers = testClass.getModifiers();
- return (java.lang.reflect.Modifier.isPublic(classModifiers)
- && !java.lang.reflect.Modifier.isAbstract(classModifiers)
- && typeDesired.isAssignableFrom(testClass) && hasEmptyConstructor(testClass));
- }
-
- public static boolean hasEmptyConstructor(Class<?> type) {
- try {
- type.getConstructor();
- return true;
+ classesFound.add(Class.forName(fullName));
} catch (Exception ignored) {
- return false;
+ // ignore classes that we cannot instantiate
}
+ }
}
+ return classesFound.toArray(new Class[classesFound.size()]);
+ }
+
+ public static boolean isPublicConcreteClassOfType(String className,
+ Class<?> typeDesired) {
+ Class<?> testClass = null;
+ try {
+ testClass = Class.forName(className);
+ } catch (Exception e) {
+ return false;
+ }
+ int classModifiers = testClass.getModifiers();
+ return (java.lang.reflect.Modifier.isPublic(classModifiers)
+ && !java.lang.reflect.Modifier.isAbstract(classModifiers)
+ && typeDesired.isAssignableFrom(testClass) && hasEmptyConstructor(testClass));
+ }
+
+ public static boolean hasEmptyConstructor(Class<?> type) {
+ try {
+ type.getConstructor();
+ return true;
+ } catch (Exception ignored) {
+ return false;
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoExpandVector.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoExpandVector.java
index 2b8ed09..071e0c4 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoExpandVector.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/AutoExpandVector.java
@@ -28,106 +28,106 @@
/**
* Vector with the capability of automatic expansion.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class AutoExpandVector<T> extends ArrayList<T> implements MOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public AutoExpandVector() {
- super(0);
- }
-
- public AutoExpandVector(int size) {
- super(size);
- }
+ public AutoExpandVector() {
+ super(0);
+ }
- @Override
- public void add(int pos, T obj) {
- if (pos > size()) {
- while (pos > size()) {
- add(null);
- }
- trimToSize();
- }
- super.add(pos, obj);
- }
+ public AutoExpandVector(int size) {
+ super(size);
+ }
- @Override
- public T get(int pos) {
- return ((pos >= 0) && (pos < size())) ? super.get(pos) : null;
+ @Override
+ public void add(int pos, T obj) {
+ if (pos > size()) {
+ while (pos > size()) {
+ add(null);
+ }
+ trimToSize();
}
+ super.add(pos, obj);
+ }
- @Override
- public T set(int pos, T obj) {
- if (pos >= size()) {
- add(pos, obj);
- return null;
- }
- return super.set(pos, obj);
- }
+ @Override
+ public T get(int pos) {
+ return ((pos >= 0) && (pos < size())) ? super.get(pos) : null;
+ }
- @Override
- public boolean add(T arg0) {
- boolean result = super.add(arg0);
- trimToSize();
- return result;
+ @Override
+ public T set(int pos, T obj) {
+ if (pos >= size()) {
+ add(pos, obj);
+ return null;
}
+ return super.set(pos, obj);
+ }
- @Override
- public boolean addAll(Collection<? extends T> arg0) {
- boolean result = super.addAll(arg0);
- trimToSize();
- return result;
- }
+ @Override
+ public boolean add(T arg0) {
+ boolean result = super.add(arg0);
+ trimToSize();
+ return result;
+ }
- @Override
- public boolean addAll(int arg0, Collection<? extends T> arg1) {
- boolean result = super.addAll(arg0, arg1);
- trimToSize();
- return result;
- }
+ @Override
+ public boolean addAll(Collection<? extends T> arg0) {
+ boolean result = super.addAll(arg0);
+ trimToSize();
+ return result;
+ }
- @Override
- public void clear() {
- super.clear();
- trimToSize();
- }
+ @Override
+ public boolean addAll(int arg0, Collection<? extends T> arg1) {
+ boolean result = super.addAll(arg0, arg1);
+ trimToSize();
+ return result;
+ }
- @Override
- public T remove(int arg0) {
- T result = super.remove(arg0);
- trimToSize();
- return result;
- }
+ @Override
+ public void clear() {
+ super.clear();
+ trimToSize();
+ }
- @Override
- public boolean remove(Object arg0) {
- boolean result = super.remove(arg0);
- trimToSize();
- return result;
- }
+ @Override
+ public T remove(int arg0) {
+ T result = super.remove(arg0);
+ trimToSize();
+ return result;
+ }
- @Override
- protected void removeRange(int arg0, int arg1) {
- super.removeRange(arg0, arg1);
- trimToSize();
- }
+ @Override
+ public boolean remove(Object arg0) {
+ boolean result = super.remove(arg0);
+ trimToSize();
+ return result;
+ }
- @Override
- public MOAObject copy() {
- return AbstractMOAObject.copy(this);
- }
+ @Override
+ protected void removeRange(int arg0, int arg1) {
+ super.removeRange(arg0, arg1);
+ trimToSize();
+ }
- @Override
- public int measureByteSize() {
- return AbstractMOAObject.measureByteSize(this);
- }
+ @Override
+ public MOAObject copy() {
+ return AbstractMOAObject.copy(this);
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public int measureByteSize() {
+ return AbstractMOAObject.measureByteSize(this);
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DataPoint.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DataPoint.java
index 21ad3df..a5a82af 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DataPoint.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DataPoint.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.core;
/*
@@ -29,106 +28,107 @@
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
-public class DataPoint extends DenseInstance{
-
- private static final long serialVersionUID = 1L;
-
- protected int timestamp;
- private HashMap<String, String> measure_values;
-
- protected int noiseLabel;
+public class DataPoint extends DenseInstance {
- public DataPoint(Instance nextInstance, Integer timestamp) {
- super(nextInstance);
- this.setDataset(nextInstance.dataset());
- this.timestamp = timestamp;
- measure_values = new HashMap<String, String>();
-
- Attribute classLabel = dataset().classAttribute();
- noiseLabel = classLabel.indexOfValue("noise"); // -1 returned if there is no noise
+ private static final long serialVersionUID = 1L;
+
+ protected int timestamp;
+ private HashMap<String, String> measure_values;
+
+ protected int noiseLabel;
+
+ public DataPoint(Instance nextInstance, Integer timestamp) {
+ super(nextInstance);
+ this.setDataset(nextInstance.dataset());
+ this.timestamp = timestamp;
+ measure_values = new HashMap<String, String>();
+
+ Attribute classLabel = dataset().classAttribute();
+ noiseLabel = classLabel.indexOfValue("noise"); // -1 returned if there is no
+ // noise
+ }
+
+ public void updateWeight(int cur_timestamp, double decay_rate) {
+ setWeight(Math.pow(2, (-1.0) * decay_rate * (cur_timestamp - timestamp)));
+ }
+
+ public void setMeasureValue(String measureKey, double value) {
+ synchronized (measure_values) {
+ measure_values.put(measureKey, Double.toString(value));
+ }
+ }
+
+ public void setMeasureValue(String measureKey, String value) {
+ synchronized (measure_values) {
+ measure_values.put(measureKey, value);
+ }
+ }
+
+ public String getMeasureValue(String measureKey) {
+ if (measure_values.containsKey(measureKey))
+ synchronized (measure_values) {
+ return measure_values.get(measureKey);
+ }
+ else
+ return "";
+ }
+
+ public int getTimestamp() {
+ return timestamp;
+ }
+
+ public String getInfo(int x_dim, int y_dim) {
+ StringBuffer sb = new StringBuffer();
+ sb.append("<html><table>");
+ sb.append("<tr><td>Point</td><td>" + timestamp + "</td></tr>");
+ for (int i = 0; i < numAttributes() - 1; i++) { // m_AttValues.length
+ String label = "Dim " + i;
+ if (i == x_dim)
+ label = "<b>X</b>";
+ if (i == y_dim)
+ label = "<b>Y</b>";
+ sb.append("<tr><td>" + label + "</td><td>" + value(i) + "</td></tr>");
+ }
+ sb.append("<tr><td>Decay</td><td>" + weight() + "</td></tr>");
+ sb.append("<tr><td>True cluster</td><td>" + classValue() + "</td></tr>");
+ sb.append("</table>");
+ sb.append("<br>");
+ sb.append("<b>Evaluation</b><br>");
+ sb.append("<table>");
+
+ TreeSet<String> sortedset;
+ synchronized (measure_values) {
+ sortedset = new TreeSet<String>(measure_values.keySet());
}
- public void updateWeight(int cur_timestamp, double decay_rate){
- setWeight(Math.pow(2,(-1.0)*decay_rate*(cur_timestamp-timestamp)));
+ Iterator miterator = sortedset.iterator();
+ while (miterator.hasNext()) {
+ String key = (String) miterator.next();
+ sb.append("<tr><td>" + key + "</td><td>" + measure_values.get(key) + "</td></tr>");
}
- public void setMeasureValue(String measureKey, double value){
- synchronized(measure_values){
- measure_values.put(measureKey, Double.toString(value));
- }
+ sb.append("</table></html>");
+ return sb.toString();
+ }
+
+ public double getDistance(DataPoint other) {
+ double distance = 0.0;
+ int numDims = numAttributes();
+ if (classIndex() != 0)
+ numDims--;
+
+ for (int i = 0; i < numDims; i++) {
+ double d = value(i) - other.value(i);
+ distance += d * d;
}
+ return Math.sqrt(distance);
+ }
- public void setMeasureValue(String measureKey,String value){
- synchronized(measure_values){
- measure_values.put(measureKey, value);
- }
- }
+ public boolean isNoise() {
+ return (int) classValue() == noiseLabel;
+ }
- public String getMeasureValue(String measureKey){
- if(measure_values.containsKey(measureKey))
- synchronized(measure_values){
- return measure_values.get(measureKey);
- }
- else
- return "";
- }
-
- public int getTimestamp(){
- return timestamp;
- }
-
- public String getInfo(int x_dim, int y_dim) {
- StringBuffer sb = new StringBuffer();
- sb.append("<html><table>");
- sb.append("<tr><td>Point</td><td>"+timestamp+"</td></tr>");
- for (int i = 0; i < numAttributes() - 1; i++) { //m_AttValues.length
- String label = "Dim "+i;
- if(i == x_dim)
- label = "<b>X</b>";
- if(i == y_dim)
- label = "<b>Y</b>";
- sb.append("<tr><td>"+label+"</td><td>"+value(i)+"</td></tr>");
- }
- sb.append("<tr><td>Decay</td><td>"+weight()+"</td></tr>");
- sb.append("<tr><td>True cluster</td><td>"+classValue()+"</td></tr>");
- sb.append("</table>");
- sb.append("<br>");
- sb.append("<b>Evaluation</b><br>");
- sb.append("<table>");
-
- TreeSet<String> sortedset;
- synchronized(measure_values){
- sortedset = new TreeSet<String>(measure_values.keySet());
- }
-
- Iterator miterator = sortedset.iterator();
- while(miterator.hasNext()) {
- String key = (String)miterator.next();
- sb.append("<tr><td>"+key+"</td><td>"+measure_values.get(key)+"</td></tr>");
- }
-
- sb.append("</table></html>");
- return sb.toString();
- }
-
- public double getDistance(DataPoint other){
- double distance = 0.0;
- int numDims = numAttributes();
- if(classIndex()!=0) numDims--;
-
- for (int i = 0; i < numDims; i++) {
- double d = value(i) - other.value(i);
- distance += d * d;
- }
- return Math.sqrt(distance);
- }
-
-
- public boolean isNoise() {
- return (int)classValue() == noiseLabel;
- }
-
- public double getNoiseLabel() {
- return noiseLabel;
- }
+ public double getNoiseLabel() {
+ return noiseLabel;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DoubleVector.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DoubleVector.java
index 1373d22..0c4d4a6 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DoubleVector.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/DoubleVector.java
@@ -24,172 +24,172 @@
/**
* Vector of double numbers with some utilities.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class DoubleVector extends AbstractMOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected double[] array;
+ protected double[] array;
- public DoubleVector() {
- this.array = new double[0];
+ public DoubleVector() {
+ this.array = new double[0];
+ }
+
+ public DoubleVector(double[] toCopy) {
+ this.array = new double[toCopy.length];
+ System.arraycopy(toCopy, 0, this.array, 0, toCopy.length);
+ }
+
+ public DoubleVector(DoubleVector toCopy) {
+ this(toCopy.getArrayRef());
+ }
+
+ public int numValues() {
+ return this.array.length;
+ }
+
+ public void setValue(int i, double v) {
+ if (i >= this.array.length) {
+ setArrayLength(i + 1);
}
+ this.array[i] = v;
+ }
- public DoubleVector(double[] toCopy) {
- this.array = new double[toCopy.length];
- System.arraycopy(toCopy, 0, this.array, 0, toCopy.length);
+ public void addToValue(int i, double v) {
+ if (i >= this.array.length) {
+ setArrayLength(i + 1);
}
+ this.array[i] += v;
+ }
- public DoubleVector(DoubleVector toCopy) {
- this(toCopy.getArrayRef());
+ public void addValues(DoubleVector toAdd) {
+ addValues(toAdd.getArrayRef());
+ }
+
+ public void addValues(double[] toAdd) {
+ if (toAdd.length > this.array.length) {
+ setArrayLength(toAdd.length);
}
-
- public int numValues() {
- return this.array.length;
+ for (int i = 0; i < toAdd.length; i++) {
+ this.array[i] += toAdd[i];
}
+ }
- public void setValue(int i, double v) {
- if (i >= this.array.length) {
- setArrayLength(i + 1);
+ public void subtractValues(DoubleVector toSubtract) {
+ subtractValues(toSubtract.getArrayRef());
+ }
+
+ public void subtractValues(double[] toSubtract) {
+ if (toSubtract.length > this.array.length) {
+ setArrayLength(toSubtract.length);
+ }
+ for (int i = 0; i < toSubtract.length; i++) {
+ this.array[i] -= toSubtract[i];
+ }
+ }
+
+ public void addToValues(double toAdd) {
+ for (int i = 0; i < this.array.length; i++) {
+ this.array[i] = this.array[i] + toAdd;
+ }
+ }
+
+ public void scaleValues(double multiplier) {
+ for (int i = 0; i < this.array.length; i++) {
+ this.array[i] = this.array[i] * multiplier;
+ }
+ }
+
+ // returns 0.0 for values outside of range
+ public double getValue(int i) {
+ return ((i >= 0) && (i < this.array.length)) ? this.array[i] : 0.0;
+ }
+
+ public double sumOfValues() {
+ double sum = 0.0;
+ for (double element : this.array) {
+ sum += element;
+ }
+ return sum;
+ }
+
+ public int maxIndex() {
+ int max = -1;
+ for (int i = 0; i < this.array.length; i++) {
+ if ((max < 0) || (this.array[i] > this.array[max])) {
+ max = i;
+ }
+ }
+ return max;
+ }
+
+ public void normalize() {
+ scaleValues(1.0 / sumOfValues());
+ }
+
+ public int numNonZeroEntries() {
+ int count = 0;
+ for (double element : this.array) {
+ if (element != 0.0) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ public double minWeight() {
+ if (this.array.length > 0) {
+ double min = this.array[0];
+ for (int i = 1; i < this.array.length; i++) {
+ if (this.array[i] < min) {
+ min = this.array[i];
}
- this.array[i] = v;
+ }
+ return min;
}
+ return 0.0;
+ }
- public void addToValue(int i, double v) {
- if (i >= this.array.length) {
- setArrayLength(i + 1);
- }
- this.array[i] += v;
- }
+ public double[] getArrayCopy() {
+ double[] aCopy = new double[this.array.length];
+ System.arraycopy(this.array, 0, aCopy, 0, this.array.length);
+ return aCopy;
+ }
- public void addValues(DoubleVector toAdd) {
- addValues(toAdd.getArrayRef());
- }
+ public double[] getArrayRef() {
+ return this.array;
+ }
- public void addValues(double[] toAdd) {
- if (toAdd.length > this.array.length) {
- setArrayLength(toAdd.length);
- }
- for (int i = 0; i < toAdd.length; i++) {
- this.array[i] += toAdd[i];
- }
+ protected void setArrayLength(int l) {
+ double[] newArray = new double[l];
+ int numToCopy = this.array.length;
+ if (numToCopy > l) {
+ numToCopy = l;
}
+ System.arraycopy(this.array, 0, newArray, 0, numToCopy);
+ this.array = newArray;
+ }
- public void subtractValues(DoubleVector toSubtract) {
- subtractValues(toSubtract.getArrayRef());
- }
+ public void getSingleLineDescription(StringBuilder out) {
+ getSingleLineDescription(out, numValues());
+ }
- public void subtractValues(double[] toSubtract) {
- if (toSubtract.length > this.array.length) {
- setArrayLength(toSubtract.length);
- }
- for (int i = 0; i < toSubtract.length; i++) {
- this.array[i] -= toSubtract[i];
- }
+ public void getSingleLineDescription(StringBuilder out, int numValues) {
+ out.append("{");
+ for (int i = 0; i < numValues; i++) {
+ if (i > 0) {
+ out.append("|");
+ }
+ out.append(StringUtils.doubleToString(getValue(i), 3));
}
+ out.append("}");
+ }
- public void addToValues(double toAdd) {
- for (int i = 0; i < this.array.length; i++) {
- this.array[i] = this.array[i] + toAdd;
- }
- }
-
- public void scaleValues(double multiplier) {
- for (int i = 0; i < this.array.length; i++) {
- this.array[i] = this.array[i] * multiplier;
- }
- }
-
- // returns 0.0 for values outside of range
- public double getValue(int i) {
- return ((i >= 0) && (i < this.array.length)) ? this.array[i] : 0.0;
- }
-
- public double sumOfValues() {
- double sum = 0.0;
- for (double element : this.array) {
- sum += element;
- }
- return sum;
- }
-
- public int maxIndex() {
- int max = -1;
- for (int i = 0; i < this.array.length; i++) {
- if ((max < 0) || (this.array[i] > this.array[max])) {
- max = i;
- }
- }
- return max;
- }
-
- public void normalize() {
- scaleValues(1.0 / sumOfValues());
- }
-
- public int numNonZeroEntries() {
- int count = 0;
- for (double element : this.array) {
- if (element != 0.0) {
- count++;
- }
- }
- return count;
- }
-
- public double minWeight() {
- if (this.array.length > 0) {
- double min = this.array[0];
- for (int i = 1; i < this.array.length; i++) {
- if (this.array[i] < min) {
- min = this.array[i];
- }
- }
- return min;
- }
- return 0.0;
- }
-
- public double[] getArrayCopy() {
- double[] aCopy = new double[this.array.length];
- System.arraycopy(this.array, 0, aCopy, 0, this.array.length);
- return aCopy;
- }
-
- public double[] getArrayRef() {
- return this.array;
- }
-
- protected void setArrayLength(int l) {
- double[] newArray = new double[l];
- int numToCopy = this.array.length;
- if (numToCopy > l) {
- numToCopy = l;
- }
- System.arraycopy(this.array, 0, newArray, 0, numToCopy);
- this.array = newArray;
- }
-
- public void getSingleLineDescription(StringBuilder out) {
- getSingleLineDescription(out, numValues());
- }
-
- public void getSingleLineDescription(StringBuilder out, int numValues) {
- out.append("{");
- for (int i = 0; i < numValues; i++) {
- if (i > 0) {
- out.append("|");
- }
- out.append(StringUtils.doubleToString(getValue(i), 3));
- }
- out.append("}");
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- getSingleLineDescription(sb);
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ getSingleLineDescription(sb);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Example.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Example.java
index 658ea14..1693350 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Example.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Example.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.core;
/*
@@ -21,11 +20,11 @@
* #L%
*/
-public interface Example< T extends Object> {
+public interface Example<T extends Object> {
- public T getData();
+ public T getData();
- public double weight();
-
- public void setWeight(double weight);
-}
+ public double weight();
+
+ public void setWeight(double weight);
+}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/FastVector.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/FastVector.java
index 102c6b1..fb03c40 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/FastVector.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/FastVector.java
@@ -1,4 +1,3 @@
-
/*
* FastVector.java
@@ -30,38 +29,41 @@
/**
* Simple extension of ArrayList. Exists for legacy reasons.
- *
+ *
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision: 8034 $
*/
public class FastVector<E> extends ArrayList<E> {
- /**
- * Adds an element to this vector. Increases its capacity if its not large
- * enough.
- *
- * @param element the element to add
- */
- public final void addElement(E element) {
- add(element);
- }
+ /**
+ * Adds an element to this vector. Increases its capacity if its not large
+ * enough.
+ *
+ * @param element
+ * the element to add
+ */
+ public final void addElement(E element) {
+ add(element);
+ }
- /**
- * Returns the element at the given position.
- *
- * @param index the element's index
- * @return the element with the given index
- */
- public final E elementAt(int index) {
- return get(index);
- }
+ /**
+ * Returns the element at the given position.
+ *
+ * @param index
+ * the element's index
+ * @return the element with the given index
+ */
+ public final E elementAt(int index) {
+ return get(index);
+ }
- /**
- * Deletes an element from this vector.
- *
- * @param index the index of the element to be deleted
- */
- public final void removeElementAt(int index) {
- remove(index);
- }
+ /**
+ * Deletes an element from this vector.
+ *
+ * @param index
+ * the index of the element to be deleted
+ */
+ public final void removeElementAt(int index) {
+ remove(index);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GaussianEstimator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GaussianEstimator.java
index e784e29..6677bf0 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GaussianEstimator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GaussianEstimator.java
@@ -23,101 +23,105 @@
import com.yahoo.labs.samoa.moa.AbstractMOAObject;
/**
- * Gaussian incremental estimator that uses incremental method that is more resistant to floating point imprecision.
- * for more info see Donald Knuth's "The Art of Computer Programming, Volume 2: Seminumerical Algorithms", section 4.2.2.
- *
+ * Gaussian incremental estimator that uses incremental method that is more
+ * resistant to floating point imprecision. for more info see Donald Knuth's
+ * "The Art of Computer Programming, Volume 2: Seminumerical Algorithms",
+ * section 4.2.2.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class GaussianEstimator extends AbstractMOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected double weightSum;
+ protected double weightSum;
- protected double mean;
+ protected double mean;
- protected double varianceSum;
+ protected double varianceSum;
- public static final double NORMAL_CONSTANT = Math.sqrt(2 * Math.PI);
+ public static final double NORMAL_CONSTANT = Math.sqrt(2 * Math.PI);
- public void addObservation(double value, double weight) {
- if (Double.isInfinite(value) || Double.isNaN(value)) {
- return;
- }
- if (this.weightSum > 0.0) {
- this.weightSum += weight;
- double lastMean = this.mean;
- this.mean += weight * (value - lastMean) / this.weightSum;
- this.varianceSum += weight * (value - lastMean) * (value - this.mean);
- } else {
- this.mean = value;
- this.weightSum = weight;
- }
+ public void addObservation(double value, double weight) {
+ if (Double.isInfinite(value) || Double.isNaN(value)) {
+ return;
}
-
- public void addObservations(GaussianEstimator obs) {
- // Follows Variance Combination Rule in Section 2 of
- // Brian Babcock, Mayur Datar, Rajeev Motwani, Liadan O'Callaghan:
- // Maintaining variance and k-medians over data stream windows. PODS 2003: 234-243
- //
- if ((this.weightSum >= 0.0) && (obs.weightSum > 0.0)) {
- double oldMean = this.mean;
- this.mean = (this.mean * (this.weightSum / (this.weightSum + obs.weightSum)))
- + (obs.mean * (obs.weightSum / (this.weightSum + obs.weightSum)));
- this.varianceSum += obs.varianceSum + (this.weightSum * obs.weightSum / (this.weightSum + obs.weightSum) *
- Math.pow(obs.mean-oldMean, 2));
- this.weightSum += obs.weightSum;
- }
+ if (this.weightSum > 0.0) {
+ this.weightSum += weight;
+ double lastMean = this.mean;
+ this.mean += weight * (value - lastMean) / this.weightSum;
+ this.varianceSum += weight * (value - lastMean) * (value - this.mean);
+ } else {
+ this.mean = value;
+ this.weightSum = weight;
}
+ }
- public double getTotalWeightObserved() {
- return this.weightSum;
+ public void addObservations(GaussianEstimator obs) {
+ // Follows Variance Combination Rule in Section 2 of
+ // Brian Babcock, Mayur Datar, Rajeev Motwani, Liadan O'Callaghan:
+ // Maintaining variance and k-medians over data stream windows. PODS 2003:
+ // 234-243
+ //
+ if ((this.weightSum >= 0.0) && (obs.weightSum > 0.0)) {
+ double oldMean = this.mean;
+ this.mean = (this.mean * (this.weightSum / (this.weightSum + obs.weightSum)))
+ + (obs.mean * (obs.weightSum / (this.weightSum + obs.weightSum)));
+ this.varianceSum += obs.varianceSum + (this.weightSum * obs.weightSum / (this.weightSum + obs.weightSum) *
+ Math.pow(obs.mean - oldMean, 2));
+ this.weightSum += obs.weightSum;
}
+ }
- public double getMean() {
- return this.mean;
- }
+ public double getTotalWeightObserved() {
+ return this.weightSum;
+ }
- public double getStdDev() {
- return Math.sqrt(getVariance());
- }
+ public double getMean() {
+ return this.mean;
+ }
- public double getVariance() {
- return this.weightSum > 1.0 ? this.varianceSum / (this.weightSum - 1.0)
- : 0.0;
- }
+ public double getStdDev() {
+ return Math.sqrt(getVariance());
+ }
- public double probabilityDensity(double value) {
- if (this.weightSum > 0.0) {
- double stdDev = getStdDev();
- if (stdDev > 0.0) {
- double diff = value - getMean();
- return (1.0 / (NORMAL_CONSTANT * stdDev))
- * Math.exp(-(diff * diff / (2.0 * stdDev * stdDev)));
- }
- return value == getMean() ? 1.0 : 0.0;
- }
- return 0.0;
- }
+ public double getVariance() {
+ return this.weightSum > 1.0 ? this.varianceSum / (this.weightSum - 1.0)
+ : 0.0;
+ }
- public double[] estimatedWeight_LessThan_EqualTo_GreaterThan_Value(
- double value) {
- double equalToWeight = probabilityDensity(value) * this.weightSum;
- double stdDev = getStdDev();
- double lessThanWeight = stdDev > 0.0 ? com.yahoo.labs.samoa.moa.core.Statistics.normalProbability((value - getMean()) / stdDev)
- * this.weightSum - equalToWeight
- : (value < getMean() ? this.weightSum - equalToWeight : 0.0);
- double greaterThanWeight = this.weightSum - equalToWeight
- - lessThanWeight;
- if (greaterThanWeight < 0.0) {
- greaterThanWeight = 0.0;
- }
- return new double[]{lessThanWeight, equalToWeight, greaterThanWeight};
+ public double probabilityDensity(double value) {
+ if (this.weightSum > 0.0) {
+ double stdDev = getStdDev();
+ if (stdDev > 0.0) {
+ double diff = value - getMean();
+ return (1.0 / (NORMAL_CONSTANT * stdDev))
+ * Math.exp(-(diff * diff / (2.0 * stdDev * stdDev)));
+ }
+ return value == getMean() ? 1.0 : 0.0;
}
+ return 0.0;
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
+ public double[] estimatedWeight_LessThan_EqualTo_GreaterThan_Value(
+ double value) {
+ double equalToWeight = probabilityDensity(value) * this.weightSum;
+ double stdDev = getStdDev();
+ double lessThanWeight = stdDev > 0.0 ? com.yahoo.labs.samoa.moa.core.Statistics
+ .normalProbability((value - getMean()) / stdDev)
+ * this.weightSum - equalToWeight
+ : (value < getMean() ? this.weightSum - equalToWeight : 0.0);
+ double greaterThanWeight = this.weightSum - equalToWeight
+ - lessThanWeight;
+ if (greaterThanWeight < 0.0) {
+ greaterThanWeight = 0.0;
}
+ return new double[] { lessThanWeight, equalToWeight, greaterThanWeight };
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GreenwaldKhannaQuantileSummary.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GreenwaldKhannaQuantileSummary.java
index 59af67d..43bce69 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GreenwaldKhannaQuantileSummary.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/GreenwaldKhannaQuantileSummary.java
@@ -27,255 +27,255 @@
/**
* Class for representing summaries of Greenwald and Khanna quantiles.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class GreenwaldKhannaQuantileSummary extends AbstractMOAObject {
+ private static final long serialVersionUID = 1L;
+
+ protected static class Tuple implements Serializable {
+
private static final long serialVersionUID = 1L;
- protected static class Tuple implements Serializable {
+ public double v;
- private static final long serialVersionUID = 1L;
+ public long g;
- public double v;
+ public long delta;
- public long g;
+ public Tuple(double v, long g, long delta) {
+ this.v = v;
+ this.g = g;
+ this.delta = delta;
+ }
- public long delta;
+ public Tuple(double v) {
+ this(v, 1, 0);
+ }
+ }
- public Tuple(double v, long g, long delta) {
- this.v = v;
- this.g = g;
- this.delta = delta;
+ protected Tuple[] summary;
+
+ protected int numTuples = 0;
+
+ protected long numObservations = 0;
+
+ public GreenwaldKhannaQuantileSummary(int maxTuples) {
+ this.summary = new Tuple[maxTuples];
+ }
+
+ public void insert(double val) {
+ int i = findIndexOfTupleGreaterThan(val);
+ Tuple nextT = this.summary[i];
+ if (nextT == null) {
+ insertTuple(new Tuple(val, 1, 0), i);
+ } else {
+ insertTuple(new Tuple(val, 1, nextT.g + nextT.delta - 1), i);
+ }
+ if (this.numTuples == this.summary.length) {
+ // use method 1
+ deleteMergeableTupleMostFull();
+ // if (mergeMethod == 1) {
+ // deleteMergeableTupleMostFull();
+ // } else if (mergeMethod == 2) {
+ // deleteTupleMostFull();
+ // } else {
+ // long maxDelta = findMaxDelta();
+ // compress(maxDelta);
+ // while (numTuples == summary.length) {
+ // maxDelta++;
+ // compress(maxDelta);
+ // }
+ // }
+ }
+ this.numObservations++;
+ }
+
+ protected void insertTuple(Tuple t, int index) {
+ System.arraycopy(this.summary, index, this.summary, index + 1,
+ this.numTuples - index);
+ this.summary[index] = t;
+ this.numTuples++;
+ }
+
+ protected void deleteTuple(int index) {
+ this.summary[index] = new Tuple(this.summary[index + 1].v,
+ this.summary[index].g + this.summary[index + 1].g,
+ this.summary[index + 1].delta);
+ System.arraycopy(this.summary, index + 2, this.summary, index + 1,
+ this.numTuples - index - 2);
+ this.summary[this.numTuples - 1] = null;
+ this.numTuples--;
+ }
+
+ protected void deleteTupleMostFull() {
+ long leastFullness = Long.MAX_VALUE;
+ int leastFullIndex = 0;
+ for (int i = 1; i < this.numTuples - 1; i++) {
+ long fullness = this.summary[i].g + this.summary[i + 1].g
+ + this.summary[i + 1].delta;
+ if (fullness < leastFullness) {
+ leastFullness = fullness;
+ leastFullIndex = i;
+ }
+ }
+ if (leastFullIndex > 0) {
+ deleteTuple(leastFullIndex);
+ }
+ }
+
+ protected void deleteMergeableTupleMostFull() {
+ long leastFullness = Long.MAX_VALUE;
+ int leastFullIndex = 0;
+ for (int i = 1; i < this.numTuples - 1; i++) {
+ long fullness = this.summary[i].g + this.summary[i + 1].g
+ + this.summary[i + 1].delta;
+ if ((this.summary[i].delta >= this.summary[i + 1].delta)
+ && (fullness < leastFullness)) {
+ leastFullness = fullness;
+ leastFullIndex = i;
+ }
+ }
+ if (leastFullIndex > 0) {
+ deleteTuple(leastFullIndex);
+ }
+ }
+
+ public long getWorstError() {
+ long mostFullness = 0;
+ for (int i = 1; i < this.numTuples - 1; i++) {
+ long fullness = this.summary[i].g + this.summary[i].delta;
+ if (fullness > mostFullness) {
+ mostFullness = fullness;
+ }
+ }
+ return mostFullness;
+ }
+
+ public long findMaxDelta() {
+ long maxDelta = 0;
+ for (int i = 0; i < this.numTuples; i++) {
+ if (this.summary[i].delta > maxDelta) {
+ maxDelta = this.summary[i].delta;
+ }
+ }
+ return maxDelta;
+ }
+
+ public void compress(long maxDelta) {
+ long[] bandBoundaries = computeBandBoundaries(maxDelta);
+ for (int i = this.numTuples - 2; i >= 0; i--) {
+ if (this.summary[i].delta >= this.summary[i + 1].delta) {
+ int band = 0;
+ while (this.summary[i].delta < bandBoundaries[band]) {
+ band++;
}
-
- public Tuple(double v) {
- this(v, 1, 0);
+ long belowBandThreshold = Long.MAX_VALUE;
+ if (band > 0) {
+ belowBandThreshold = bandBoundaries[band - 1];
}
- }
-
- protected Tuple[] summary;
-
- protected int numTuples = 0;
-
- protected long numObservations = 0;
-
- public GreenwaldKhannaQuantileSummary(int maxTuples) {
- this.summary = new Tuple[maxTuples];
- }
-
- public void insert(double val) {
- int i = findIndexOfTupleGreaterThan(val);
- Tuple nextT = this.summary[i];
- if (nextT == null) {
- insertTuple(new Tuple(val, 1, 0), i);
- } else {
- insertTuple(new Tuple(val, 1, nextT.g + nextT.delta - 1), i);
+ long mergeG = this.summary[i + 1].g + this.summary[i].g;
+ int childI = i - 1;
+ while (((mergeG + this.summary[i + 1].delta) < maxDelta)
+ && (childI >= 0)
+ && (this.summary[childI].delta >= belowBandThreshold)) {
+ mergeG += this.summary[childI].g;
+ childI--;
}
- if (this.numTuples == this.summary.length) {
- // use method 1
- deleteMergeableTupleMostFull();
- // if (mergeMethod == 1) {
- // deleteMergeableTupleMostFull();
- // } else if (mergeMethod == 2) {
- // deleteTupleMostFull();
- // } else {
- // long maxDelta = findMaxDelta();
- // compress(maxDelta);
- // while (numTuples == summary.length) {
- // maxDelta++;
- // compress(maxDelta);
- // }
- // }
+ if (mergeG + this.summary[i + 1].delta < maxDelta) {
+ // merge
+ int numDeleted = i - childI;
+ this.summary[childI + 1] = new Tuple(this.summary[i + 1].v,
+ mergeG, this.summary[i + 1].delta);
+ // todo complete & test this multiple delete
+ System.arraycopy(this.summary, i + 2, this.summary,
+ childI + 2, this.numTuples - (i + 2));
+ for (int j = this.numTuples - numDeleted; j < this.numTuples; j++) {
+ this.summary[j] = null;
+ }
+ this.numTuples -= numDeleted;
+ i = childI + 1;
}
- this.numObservations++;
+ }
}
+ }
- protected void insertTuple(Tuple t, int index) {
- System.arraycopy(this.summary, index, this.summary, index + 1,
- this.numTuples - index);
- this.summary[index] = t;
- this.numTuples++;
+ public double getQuantile(double quant) {
+ long r = (long) Math.ceil(quant * this.numObservations);
+ long currRank = 0;
+ for (int i = 0; i < this.numTuples - 1; i++) {
+ currRank += this.summary[i].g;
+ if (currRank + this.summary[i + 1].g > r) {
+ return this.summary[i].v;
+ }
}
+ return this.summary[this.numTuples - 1].v;
+ }
- protected void deleteTuple(int index) {
- this.summary[index] = new Tuple(this.summary[index + 1].v,
- this.summary[index].g + this.summary[index + 1].g,
- this.summary[index + 1].delta);
- System.arraycopy(this.summary, index + 2, this.summary, index + 1,
- this.numTuples - index - 2);
- this.summary[this.numTuples - 1] = null;
- this.numTuples--;
- }
+ public long getTotalCount() {
+ return this.numObservations;
+ }
- protected void deleteTupleMostFull() {
- long leastFullness = Long.MAX_VALUE;
- int leastFullIndex = 0;
- for (int i = 1; i < this.numTuples - 1; i++) {
- long fullness = this.summary[i].g + this.summary[i + 1].g
- + this.summary[i + 1].delta;
- if (fullness < leastFullness) {
- leastFullness = fullness;
- leastFullIndex = i;
- }
- }
- if (leastFullIndex > 0) {
- deleteTuple(leastFullIndex);
- }
- }
+ public double getPropotionBelow(double cutpoint) {
+ return (double) getCountBelow(cutpoint) / (double) this.numObservations;
+ }
- protected void deleteMergeableTupleMostFull() {
- long leastFullness = Long.MAX_VALUE;
- int leastFullIndex = 0;
- for (int i = 1; i < this.numTuples - 1; i++) {
- long fullness = this.summary[i].g + this.summary[i + 1].g
- + this.summary[i + 1].delta;
- if ((this.summary[i].delta >= this.summary[i + 1].delta)
- && (fullness < leastFullness)) {
- leastFullness = fullness;
- leastFullIndex = i;
- }
- }
- if (leastFullIndex > 0) {
- deleteTuple(leastFullIndex);
- }
+ public long getCountBelow(double cutpoint) {
+ long rank = 0;
+ for (int i = 0; i < this.numTuples; i++) {
+ if (this.summary[i].v > cutpoint) {
+ break;
+ }
+ rank += this.summary[i].g;
}
+ return rank;
+ }
- public long getWorstError() {
- long mostFullness = 0;
- for (int i = 1; i < this.numTuples - 1; i++) {
- long fullness = this.summary[i].g + this.summary[i].delta;
- if (fullness > mostFullness) {
- mostFullness = fullness;
- }
- }
- return mostFullness;
+ public double[] getSuggestedCutpoints() {
+ double[] cutpoints = new double[this.numTuples];
+ for (int i = 0; i < this.numTuples; i++) {
+ cutpoints[i] = this.summary[i].v;
}
+ return cutpoints;
+ }
- public long findMaxDelta() {
- long maxDelta = 0;
- for (int i = 0; i < this.numTuples; i++) {
- if (this.summary[i].delta > maxDelta) {
- maxDelta = this.summary[i].delta;
- }
- }
- return maxDelta;
+ protected int findIndexOfTupleGreaterThan(double val) {
+ int high = this.numTuples, low = -1, probe;
+ while (high - low > 1) {
+ probe = (high + low) / 2;
+ if (this.summary[probe].v > val) {
+ high = probe;
+ } else {
+ low = probe;
+ }
}
+ return high;
+ }
- public void compress(long maxDelta) {
- long[] bandBoundaries = computeBandBoundaries(maxDelta);
- for (int i = this.numTuples - 2; i >= 0; i--) {
- if (this.summary[i].delta >= this.summary[i + 1].delta) {
- int band = 0;
- while (this.summary[i].delta < bandBoundaries[band]) {
- band++;
- }
- long belowBandThreshold = Long.MAX_VALUE;
- if (band > 0) {
- belowBandThreshold = bandBoundaries[band - 1];
- }
- long mergeG = this.summary[i + 1].g + this.summary[i].g;
- int childI = i - 1;
- while (((mergeG + this.summary[i + 1].delta) < maxDelta)
- && (childI >= 0)
- && (this.summary[childI].delta >= belowBandThreshold)) {
- mergeG += this.summary[childI].g;
- childI--;
- }
- if (mergeG + this.summary[i + 1].delta < maxDelta) {
- // merge
- int numDeleted = i - childI;
- this.summary[childI + 1] = new Tuple(this.summary[i + 1].v,
- mergeG, this.summary[i + 1].delta);
- // todo complete & test this multiple delete
- System.arraycopy(this.summary, i + 2, this.summary,
- childI + 2, this.numTuples - (i + 2));
- for (int j = this.numTuples - numDeleted; j < this.numTuples; j++) {
- this.summary[j] = null;
- }
- this.numTuples -= numDeleted;
- i = childI + 1;
- }
- }
- }
+ public static long[] computeBandBoundaries(long maxDelta) {
+ ArrayList<Long> boundaryList = new ArrayList<Long>();
+ boundaryList.add(new Long(maxDelta));
+ int alpha = 1;
+ while (true) {
+ long boundary = (maxDelta - (2 << (alpha - 1)) - (maxDelta % (2 << (alpha - 1))));
+ if (boundary >= 0) {
+ boundaryList.add(new Long(boundary + 1));
+ } else {
+ break;
+ }
+ alpha++;
}
+ boundaryList.add(new Long(0));
+ long[] boundaries = new long[boundaryList.size()];
+ for (int i = 0; i < boundaries.length; i++) {
+ boundaries[i] = boundaryList.get(i).longValue();
+ }
+ return boundaries;
+ }
- public double getQuantile(double quant) {
- long r = (long) Math.ceil(quant * this.numObservations);
- long currRank = 0;
- for (int i = 0; i < this.numTuples - 1; i++) {
- currRank += this.summary[i].g;
- if (currRank + this.summary[i + 1].g > r) {
- return this.summary[i].v;
- }
- }
- return this.summary[this.numTuples - 1].v;
- }
-
- public long getTotalCount() {
- return this.numObservations;
- }
-
- public double getPropotionBelow(double cutpoint) {
- return (double) getCountBelow(cutpoint) / (double) this.numObservations;
- }
-
- public long getCountBelow(double cutpoint) {
- long rank = 0;
- for (int i = 0; i < this.numTuples; i++) {
- if (this.summary[i].v > cutpoint) {
- break;
- }
- rank += this.summary[i].g;
- }
- return rank;
- }
-
- public double[] getSuggestedCutpoints() {
- double[] cutpoints = new double[this.numTuples];
- for (int i = 0; i < this.numTuples; i++) {
- cutpoints[i] = this.summary[i].v;
- }
- return cutpoints;
- }
-
- protected int findIndexOfTupleGreaterThan(double val) {
- int high = this.numTuples, low = -1, probe;
- while (high - low > 1) {
- probe = (high + low) / 2;
- if (this.summary[probe].v > val) {
- high = probe;
- } else {
- low = probe;
- }
- }
- return high;
- }
-
- public static long[] computeBandBoundaries(long maxDelta) {
- ArrayList<Long> boundaryList = new ArrayList<Long>();
- boundaryList.add(new Long(maxDelta));
- int alpha = 1;
- while (true) {
- long boundary = (maxDelta - (2 << (alpha - 1)) - (maxDelta % (2 << (alpha - 1))));
- if (boundary >= 0) {
- boundaryList.add(new Long(boundary + 1));
- } else {
- break;
- }
- alpha++;
- }
- boundaryList.add(new Long(0));
- long[] boundaries = new long[boundaryList.size()];
- for (int i = 0; i < boundaries.length; i++) {
- boundaries[i] = boundaryList.get(i).longValue();
- }
- return boundaries;
- }
-
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InputStreamProgressMonitor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InputStreamProgressMonitor.java
index 9cda129..8b7dd87 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InputStreamProgressMonitor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InputStreamProgressMonitor.java
@@ -27,105 +27,105 @@
/**
* Class for monitoring the progress of reading an input stream.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class InputStreamProgressMonitor extends FilterInputStream implements Serializable {
- /** The number of bytes to read in total */
- protected int inputByteSize;
+ /** The number of bytes to read in total */
+ protected int inputByteSize;
- /** The number of bytes read so far */
- protected int inputBytesRead;
-
- public InputStreamProgressMonitor(InputStream in) {
- super(in);
- try {
- this.inputByteSize = in.available();
- } catch (IOException ioe) {
- this.inputByteSize = 0;
- }
- this.inputBytesRead = 0;
- }
-
- public int getBytesRead() {
- return this.inputBytesRead;
- }
+ /** The number of bytes read so far */
+ protected int inputBytesRead;
- public int getBytesRemaining() {
- return this.inputByteSize - this.inputBytesRead;
- }
+ public InputStreamProgressMonitor(InputStream in) {
+ super(in);
+ try {
+ this.inputByteSize = in.available();
+ } catch (IOException ioe) {
+ this.inputByteSize = 0;
+ }
+ this.inputBytesRead = 0;
+ }
- public double getProgressFraction() {
- return ((double) this.inputBytesRead / (double) this.inputByteSize);
- }
+ public int getBytesRead() {
+ return this.inputBytesRead;
+ }
- /*
- * (non-Javadoc)
- *
- * @see java.io.InputStream#read()
- */
- @Override
- public int read() throws IOException {
- int c = this.in.read();
- if (c > 0) {
- this.inputBytesRead++;
- }
- return c;
- }
+ public int getBytesRemaining() {
+ return this.inputByteSize - this.inputBytesRead;
+ }
- /*
- * (non-Javadoc)
- *
- * @see java.io.InputStream#read(byte[])
- */
- @Override
- public int read(byte[] b) throws IOException {
- int numread = this.in.read(b);
- if (numread > 0) {
- this.inputBytesRead += numread;
- }
- return numread;
- }
+ public double getProgressFraction() {
+ return ((double) this.inputBytesRead / (double) this.inputByteSize);
+ }
- /*
- * (non-Javadoc)
- *
- * @see java.io.InputStream#read(byte[], int, int)
- */
- @Override
- public int read(byte[] b, int off, int len) throws IOException {
- int numread = this.in.read(b, off, len);
- if (numread > 0) {
- this.inputBytesRead += numread;
- }
- return numread;
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see java.io.InputStream#read()
+ */
+ @Override
+ public int read() throws IOException {
+ int c = this.in.read();
+ if (c > 0) {
+ this.inputBytesRead++;
+ }
+ return c;
+ }
- /*
- * (non-Javadoc)
- *
- * @see java.io.InputStream#skip(long)
- */
- @Override
- public long skip(long n) throws IOException {
- long numskip = this.in.skip(n);
- if (numskip > 0) {
- this.inputBytesRead += numskip;
- }
- return numskip;
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see java.io.InputStream#read(byte[])
+ */
+ @Override
+ public int read(byte[] b) throws IOException {
+ int numread = this.in.read(b);
+ if (numread > 0) {
+ this.inputBytesRead += numread;
+ }
+ return numread;
+ }
- /*
- * (non-Javadoc)
- *
- * @see java.io.FilterInputStream#reset()
- */
- @Override
- public synchronized void reset() throws IOException {
- this.in.reset();
- this.inputBytesRead = this.inputByteSize - this.in.available();
- }
+ /*
+ * (non-Javadoc)
+ *
+ * @see java.io.InputStream#read(byte[], int, int)
+ */
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ int numread = this.in.read(b, off, len);
+ if (numread > 0) {
+ this.inputBytesRead += numread;
+ }
+ return numread;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see java.io.InputStream#skip(long)
+ */
+ @Override
+ public long skip(long n) throws IOException {
+ long numskip = this.in.skip(n);
+ if (numskip > 0) {
+ this.inputBytesRead += numskip;
+ }
+ return numskip;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see java.io.FilterInputStream#reset()
+ */
+ @Override
+ public synchronized void reset() throws IOException {
+ this.in.reset();
+ this.inputBytesRead = this.inputByteSize - this.in.available();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InstanceExample.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InstanceExample.java
index 20866a0..490ee0a 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InstanceExample.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/InstanceExample.java
@@ -25,26 +25,26 @@
public class InstanceExample implements Example<Instance>, Serializable {
- public Instance instance;
+ public Instance instance;
- public InstanceExample (Instance inst)
- {
- this.instance = inst;
- }
+ public InstanceExample(Instance inst)
+ {
+ this.instance = inst;
+ }
- @Override
- public Instance getData() {
- return this.instance;
- }
-
- @Override
- public double weight() {
- return this.instance.weight();
- }
+ @Override
+ public Instance getData() {
+ return this.instance;
+ }
- @Override
- public void setWeight(double w) {
- this.instance.setWeight(w);
- }
+ @Override
+ public double weight() {
+ return this.instance.weight();
+ }
-}
+ @Override
+ public void setWeight(double w) {
+ this.instance.setWeight(w);
+ }
+
+}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Measurement.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Measurement.java
index 1d1f194..5b9e042 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Measurement.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Measurement.java
@@ -27,89 +27,89 @@
/**
* Class for storing an evaluation measurement.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class Measurement extends AbstractMOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected String name;
+ protected String name;
- protected double value;
+ protected double value;
- public Measurement(String name, double value) {
- this.name = name;
- this.value = value;
+ public Measurement(String name, double value) {
+ this.name = name;
+ this.value = value;
+ }
+
+ public String getName() {
+ return this.name;
+ }
+
+ public double getValue() {
+ return this.value;
+ }
+
+ public static Measurement getMeasurementNamed(String name,
+ Measurement[] measurements) {
+ for (Measurement measurement : measurements) {
+ if (name.equals(measurement.getName())) {
+ return measurement;
+ }
}
+ return null;
+ }
- public String getName() {
- return this.name;
+ public static void getMeasurementsDescription(Measurement[] measurements,
+ StringBuilder out, int indent) {
+ if (measurements.length > 0) {
+ StringUtils.appendIndented(out, indent, measurements[0].toString());
+ for (int i = 1; i < measurements.length; i++) {
+ StringUtils.appendNewlineIndented(out, indent, measurements[i].toString());
+ }
+
}
+ }
- public double getValue() {
- return this.value;
- }
-
- public static Measurement getMeasurementNamed(String name,
- Measurement[] measurements) {
- for (Measurement measurement : measurements) {
- if (name.equals(measurement.getName())) {
- return measurement;
- }
+ public static Measurement[] averageMeasurements(Measurement[][] toAverage) {
+ List<String> measurementNames = new ArrayList<String>();
+ for (Measurement[] measurements : toAverage) {
+ for (Measurement measurement : measurements) {
+ if (measurementNames.indexOf(measurement.getName()) < 0) {
+ measurementNames.add(measurement.getName());
}
- return null;
+ }
}
-
- public static void getMeasurementsDescription(Measurement[] measurements,
- StringBuilder out, int indent) {
- if (measurements.length > 0) {
- StringUtils.appendIndented(out, indent, measurements[0].toString());
- for (int i = 1; i < measurements.length; i++) {
- StringUtils.appendNewlineIndented(out, indent, measurements[i].toString());
- }
-
- }
+ GaussianEstimator[] estimators = new GaussianEstimator[measurementNames.size()];
+ for (int i = 0; i < estimators.length; i++) {
+ estimators[i] = new GaussianEstimator();
}
-
- public static Measurement[] averageMeasurements(Measurement[][] toAverage) {
- List<String> measurementNames = new ArrayList<String>();
- for (Measurement[] measurements : toAverage) {
- for (Measurement measurement : measurements) {
- if (measurementNames.indexOf(measurement.getName()) < 0) {
- measurementNames.add(measurement.getName());
- }
- }
- }
- GaussianEstimator[] estimators = new GaussianEstimator[measurementNames.size()];
- for (int i = 0; i < estimators.length; i++) {
- estimators[i] = new GaussianEstimator();
- }
- for (Measurement[] measurements : toAverage) {
- for (Measurement measurement : measurements) {
- estimators[measurementNames.indexOf(measurement.getName())].addObservation(measurement.getValue(), 1.0);
- }
- }
- List<Measurement> averagedMeasurements = new ArrayList<Measurement>();
- for (int i = 0; i < measurementNames.size(); i++) {
- String mName = measurementNames.get(i);
- GaussianEstimator mEstimator = estimators[i];
- if (mEstimator.getTotalWeightObserved() > 1.0) {
- averagedMeasurements.add(new Measurement("[avg] " + mName,
- mEstimator.getMean()));
- averagedMeasurements.add(new Measurement("[err] " + mName,
- mEstimator.getStdDev()
- / Math.sqrt(mEstimator.getTotalWeightObserved())));
- }
- }
- return averagedMeasurements.toArray(new Measurement[averagedMeasurements.size()]);
+ for (Measurement[] measurements : toAverage) {
+ for (Measurement measurement : measurements) {
+ estimators[measurementNames.indexOf(measurement.getName())].addObservation(measurement.getValue(), 1.0);
+ }
}
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- sb.append(getName());
- sb.append(" = ");
- sb.append(StringUtils.doubleToString(getValue(), 3));
+ List<Measurement> averagedMeasurements = new ArrayList<Measurement>();
+ for (int i = 0; i < measurementNames.size(); i++) {
+ String mName = measurementNames.get(i);
+ GaussianEstimator mEstimator = estimators[i];
+ if (mEstimator.getTotalWeightObserved() > 1.0) {
+ averagedMeasurements.add(new Measurement("[avg] " + mName,
+ mEstimator.getMean()));
+ averagedMeasurements.add(new Measurement("[err] " + mName,
+ mEstimator.getStdDev()
+ / Math.sqrt(mEstimator.getTotalWeightObserved())));
+ }
}
+ return averagedMeasurements.toArray(new Measurement[averagedMeasurements.size()]);
+ }
+
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ sb.append(getName());
+ sb.append(" = ");
+ sb.append(StringUtils.doubleToString(getValue(), 3));
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/MiscUtils.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/MiscUtils.java
index 4bfc98f..c176369 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/MiscUtils.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/MiscUtils.java
@@ -24,74 +24,74 @@
import java.io.StringWriter;
import java.util.Random;
-
/**
* Class implementing some utility methods.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class MiscUtils {
- public static int chooseRandomIndexBasedOnWeights(double[] weights,
- Random random) {
- double probSum = Utils.sum(weights);
- double val = random.nextDouble() * probSum;
- int index = 0;
- double sum = 0.0;
- while ((sum <= val) && (index < weights.length)) {
- sum += weights[index++];
- }
- return index - 1;
+ public static int chooseRandomIndexBasedOnWeights(double[] weights,
+ Random random) {
+ double probSum = Utils.sum(weights);
+ double val = random.nextDouble() * probSum;
+ int index = 0;
+ double sum = 0.0;
+ while ((sum <= val) && (index < weights.length)) {
+ sum += weights[index++];
+ }
+ return index - 1;
+ }
+
+ public static int poisson(double lambda, Random r) {
+ if (lambda < 100.0) {
+ double product = 1.0;
+ double sum = 1.0;
+ double threshold = r.nextDouble() * Math.exp(lambda);
+ int i = 1;
+ int max = Math.max(100, 10 * (int) Math.ceil(lambda));
+ while ((i < max) && (sum <= threshold)) {
+ product *= (lambda / i);
+ sum += product;
+ i++;
+ }
+ return i - 1;
+ }
+ double x = lambda + Math.sqrt(lambda) * r.nextGaussian();
+ if (x < 0.0) {
+ return 0;
+ }
+ return (int) Math.floor(x);
+ }
+
+ public static String getStackTraceString(Exception ex) {
+ StringWriter stackTraceWriter = new StringWriter();
+ ex.printStackTrace(new PrintWriter(stackTraceWriter));
+ return "*** STACK TRACE ***\n" + stackTraceWriter.toString();
+ }
+
+ /**
+ * Returns index of maximum element in a given array of doubles. First maximum
+ * is returned.
+ *
+ * @param doubles
+ * the array of doubles
+ * @return the index of the maximum element
+ */
+ public static/* @pure@ */int maxIndex(double[] doubles) {
+
+ double maximum = 0;
+ int maxIndex = 0;
+
+ for (int i = 0; i < doubles.length; i++) {
+ if ((i == 0) || (doubles[i] > maximum)) {
+ maxIndex = i;
+ maximum = doubles[i];
+ }
}
- public static int poisson(double lambda, Random r) {
- if (lambda < 100.0) {
- double product = 1.0;
- double sum = 1.0;
- double threshold = r.nextDouble() * Math.exp(lambda);
- int i = 1;
- int max = Math.max(100, 10 * (int) Math.ceil(lambda));
- while ((i < max) && (sum <= threshold)) {
- product *= (lambda / i);
- sum += product;
- i++;
- }
- return i - 1;
- }
- double x = lambda + Math.sqrt(lambda) * r.nextGaussian();
- if (x < 0.0) {
- return 0;
- }
- return (int) Math.floor(x);
- }
-
- public static String getStackTraceString(Exception ex) {
- StringWriter stackTraceWriter = new StringWriter();
- ex.printStackTrace(new PrintWriter(stackTraceWriter));
- return "*** STACK TRACE ***\n" + stackTraceWriter.toString();
- }
-
- /**
- * Returns index of maximum element in a given array of doubles. First
- * maximum is returned.
- *
- * @param doubles the array of doubles
- * @return the index of the maximum element
- */
- public static /*@pure@*/ int maxIndex(double[] doubles) {
-
- double maximum = 0;
- int maxIndex = 0;
-
- for (int i = 0; i < doubles.length; i++) {
- if ((i == 0) || (doubles[i] > maximum)) {
- maxIndex = i;
- maximum = doubles[i];
- }
- }
-
- return maxIndex;
- }
+ return maxIndex;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/ObjectRepository.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/ObjectRepository.java
index d326361..a7a4e35 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/ObjectRepository.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/ObjectRepository.java
@@ -22,11 +22,11 @@
/**
* Interface for object repositories.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface ObjectRepository {
- Object getObjectNamed(String string);
+ Object getObjectNamed(String string);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/SerializeUtils.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/SerializeUtils.java
index 99c9626..5030c55 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/SerializeUtils.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/SerializeUtils.java
@@ -37,76 +37,76 @@
/**
* Class implementing some serialize utility methods.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class SerializeUtils {
- public static class ByteCountingOutputStream extends OutputStream {
+ public static class ByteCountingOutputStream extends OutputStream {
- protected int numBytesWritten = 0;
+ protected int numBytesWritten = 0;
- public int getNumBytesWritten() {
- return this.numBytesWritten;
- }
-
- @Override
- public void write(int b) throws IOException {
- this.numBytesWritten++;
- }
-
- @Override
- public void write(byte[] b, int off, int len) throws IOException {
- this.numBytesWritten += len;
- }
-
- @Override
- public void write(byte[] b) throws IOException {
- this.numBytesWritten += b.length;
- }
+ public int getNumBytesWritten() {
+ return this.numBytesWritten;
}
- public static void writeToFile(File file, Serializable obj)
- throws IOException {
- ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(
- new BufferedOutputStream(new FileOutputStream(file))));
- out.writeObject(obj);
- out.flush();
- out.close();
+ @Override
+ public void write(int b) throws IOException {
+ this.numBytesWritten++;
}
- public static Object readFromFile(File file) throws IOException,
- ClassNotFoundException {
- ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(
- new BufferedInputStream(new FileInputStream(file))));
- Object obj = in.readObject();
- in.close();
- return obj;
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ this.numBytesWritten += len;
}
- public static Object copyObject(Serializable obj) throws Exception {
- ByteArrayOutputStream baoStream = new ByteArrayOutputStream();
- ObjectOutputStream out = new ObjectOutputStream(
- new BufferedOutputStream(baoStream));
- out.writeObject(obj);
- out.flush();
- out.close();
- byte[] byteArray = baoStream.toByteArray();
- ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(
- new ByteArrayInputStream(byteArray)));
- Object copy = in.readObject();
- in.close();
- return copy;
+ @Override
+ public void write(byte[] b) throws IOException {
+ this.numBytesWritten += b.length;
}
+ }
- public static int measureObjectByteSize(Serializable obj) throws Exception {
- ByteCountingOutputStream bcoStream = new ByteCountingOutputStream();
- ObjectOutputStream out = new ObjectOutputStream(
- new BufferedOutputStream(bcoStream));
- out.writeObject(obj);
- out.flush();
- out.close();
- return bcoStream.getNumBytesWritten();
- }
+ public static void writeToFile(File file, Serializable obj)
+ throws IOException {
+ ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(
+ new BufferedOutputStream(new FileOutputStream(file))));
+ out.writeObject(obj);
+ out.flush();
+ out.close();
+ }
+
+ public static Object readFromFile(File file) throws IOException,
+ ClassNotFoundException {
+ ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(
+ new BufferedInputStream(new FileInputStream(file))));
+ Object obj = in.readObject();
+ in.close();
+ return obj;
+ }
+
+ public static Object copyObject(Serializable obj) throws Exception {
+ ByteArrayOutputStream baoStream = new ByteArrayOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(
+ new BufferedOutputStream(baoStream));
+ out.writeObject(obj);
+ out.flush();
+ out.close();
+ byte[] byteArray = baoStream.toByteArray();
+ ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(
+ new ByteArrayInputStream(byteArray)));
+ Object copy = in.readObject();
+ in.close();
+ return copy;
+ }
+
+ public static int measureObjectByteSize(Serializable obj) throws Exception {
+ ByteCountingOutputStream bcoStream = new ByteCountingOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(
+ new BufferedOutputStream(bcoStream));
+ out.writeObject(obj);
+ out.flush();
+ out.close();
+ return bcoStream.getNumBytesWritten();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Statistics.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Statistics.java
index 65536db..c270de9 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Statistics.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Statistics.java
@@ -23,140 +23,150 @@
public class Statistics {
/** Some constants */
- protected static final double MACHEP = 1.11022302462515654042E-16;
- protected static final double MAXLOG = 7.09782712893383996732E2;
+ protected static final double MACHEP = 1.11022302462515654042E-16;
+ protected static final double MAXLOG = 7.09782712893383996732E2;
protected static final double MINLOG = -7.451332191019412076235E2;
protected static final double MAXGAM = 171.624376956302725;
- protected static final double SQTPI = 2.50662827463100050242E0;
- protected static final double SQRTH = 7.07106781186547524401E-1;
- protected static final double LOGPI = 1.14472988584940017414;
-
- protected static final double big = 4.503599627370496e15;
- protected static final double biginv = 2.22044604925031308085e-16;
+ protected static final double SQTPI = 2.50662827463100050242E0;
+ protected static final double SQRTH = 7.07106781186547524401E-1;
+ protected static final double LOGPI = 1.14472988584940017414;
+
+ protected static final double big = 4.503599627370496e15;
+ protected static final double biginv = 2.22044604925031308085e-16;
/*************************************************
- * COEFFICIENTS FOR METHOD normalInverse() *
+ * COEFFICIENTS FOR METHOD normalInverse() *
*************************************************/
/* approximation for 0 <= |y - 0.5| <= 3/8 */
protected static final double P0[] = {
- -5.99633501014107895267E1,
- 9.80010754185999661536E1,
- -5.66762857469070293439E1,
- 1.39312609387279679503E1,
- -1.23916583867381258016E0,
+ -5.99633501014107895267E1,
+ 9.80010754185999661536E1,
+ -5.66762857469070293439E1,
+ 1.39312609387279679503E1,
+ -1.23916583867381258016E0,
};
protected static final double Q0[] = {
- /* 1.00000000000000000000E0,*/
- 1.95448858338141759834E0,
- 4.67627912898881538453E0,
- 8.63602421390890590575E1,
- -2.25462687854119370527E2,
- 2.00260212380060660359E2,
- -8.20372256168333339912E1,
- 1.59056225126211695515E1,
- -1.18331621121330003142E0,
+ /* 1.00000000000000000000E0, */
+ 1.95448858338141759834E0,
+ 4.67627912898881538453E0,
+ 8.63602421390890590575E1,
+ -2.25462687854119370527E2,
+ 2.00260212380060660359E2,
+ -8.20372256168333339912E1,
+ 1.59056225126211695515E1,
+ -1.18331621121330003142E0,
};
-
- /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
- * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
+
+ /*
+ * Approximation for interval z = sqrt(-2 log y ) between 2 and 8 i.e., y
+ * between exp(-2) = .135 and exp(-32) = 1.27e-14.
*/
protected static final double P1[] = {
- 4.05544892305962419923E0,
- 3.15251094599893866154E1,
- 5.71628192246421288162E1,
- 4.40805073893200834700E1,
- 1.46849561928858024014E1,
- 2.18663306850790267539E0,
- -1.40256079171354495875E-1,
- -3.50424626827848203418E-2,
- -8.57456785154685413611E-4,
+ 4.05544892305962419923E0,
+ 3.15251094599893866154E1,
+ 5.71628192246421288162E1,
+ 4.40805073893200834700E1,
+ 1.46849561928858024014E1,
+ 2.18663306850790267539E0,
+ -1.40256079171354495875E-1,
+ -3.50424626827848203418E-2,
+ -8.57456785154685413611E-4,
};
protected static final double Q1[] = {
- /* 1.00000000000000000000E0,*/
- 1.57799883256466749731E1,
- 4.53907635128879210584E1,
- 4.13172038254672030440E1,
- 1.50425385692907503408E1,
- 2.50464946208309415979E0,
- -1.42182922854787788574E-1,
- -3.80806407691578277194E-2,
- -9.33259480895457427372E-4,
+ /* 1.00000000000000000000E0, */
+ 1.57799883256466749731E1,
+ 4.53907635128879210584E1,
+ 4.13172038254672030440E1,
+ 1.50425385692907503408E1,
+ 2.50464946208309415979E0,
+ -1.42182922854787788574E-1,
+ -3.80806407691578277194E-2,
+ -9.33259480895457427372E-4,
};
-
- /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
- * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
+
+ /*
+ * Approximation for interval z = sqrt(-2 log y ) between 8 and 64 i.e., y
+ * between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
*/
- protected static final double P2[] = {
- 3.23774891776946035970E0,
- 6.91522889068984211695E0,
- 3.93881025292474443415E0,
- 1.33303460815807542389E0,
- 2.01485389549179081538E-1,
- 1.23716634817820021358E-2,
- 3.01581553508235416007E-4,
- 2.65806974686737550832E-6,
- 6.23974539184983293730E-9,
+ protected static final double P2[] = {
+ 3.23774891776946035970E0,
+ 6.91522889068984211695E0,
+ 3.93881025292474443415E0,
+ 1.33303460815807542389E0,
+ 2.01485389549179081538E-1,
+ 1.23716634817820021358E-2,
+ 3.01581553508235416007E-4,
+ 2.65806974686737550832E-6,
+ 6.23974539184983293730E-9,
};
- protected static final double Q2[] = {
- /* 1.00000000000000000000E0,*/
- 6.02427039364742014255E0,
- 3.67983563856160859403E0,
- 1.37702099489081330271E0,
- 2.16236993594496635890E-1,
- 1.34204006088543189037E-2,
- 3.28014464682127739104E-4,
- 2.89247864745380683936E-6,
- 6.79019408009981274425E-9,
+ protected static final double Q2[] = {
+ /* 1.00000000000000000000E0, */
+ 6.02427039364742014255E0,
+ 3.67983563856160859403E0,
+ 1.37702099489081330271E0,
+ 2.16236993594496635890E-1,
+ 1.34204006088543189037E-2,
+ 3.28014464682127739104E-4,
+ 2.89247864745380683936E-6,
+ 6.79019408009981274425E-9,
};
-
+
/**
- * Computes standard error for observed values of a binomial
- * random variable.
- *
- * @param p the probability of success
- * @param n the size of the sample
+ * Computes standard error for observed values of a binomial random variable.
+ *
+ * @param p
+ * the probability of success
+ * @param n
+ * the size of the sample
* @return the standard error
*/
public static double binomialStandardError(double p, int n) {
-
+
if (n == 0) {
- return 0;
+ return 0;
}
- return Math.sqrt((p*(1-p))/(double) n);
+ return Math.sqrt((p * (1 - p)) / (double) n);
}
-
+
/**
- * Returns chi-squared probability for given value and degrees
- * of freedom. (The probability that the chi-squared variate
- * will be greater than x for the given degrees of freedom.)
- *
- * @param x the value
- * @param v the number of degrees of freedom
+ * Returns chi-squared probability for given value and degrees of freedom.
+ * (The probability that the chi-squared variate will be greater than x for
+ * the given degrees of freedom.)
+ *
+ * @param x
+ * the value
+ * @param v
+ * the number of degrees of freedom
* @return the chi-squared probability
*/
- public static double chiSquaredProbability(double x, double v) {
+ public static double chiSquaredProbability(double x, double v) {
- if( x < 0.0 || v < 1.0 ) return 0.0;
- return incompleteGammaComplement( v/2.0, x/2.0 );
+ if (x < 0.0 || v < 1.0)
+ return 0.0;
+ return incompleteGammaComplement(v / 2.0, x / 2.0);
}
/**
* Computes probability of F-ratio.
- *
- * @param F the F-ratio
- * @param df1 the first number of degrees of freedom
- * @param df2 the second number of degrees of freedom
+ *
+ * @param F
+ * the F-ratio
+ * @param df1
+ * the first number of degrees of freedom
+ * @param df2
+ * the second number of degrees of freedom
* @return the probability of the F-ratio.
*/
public static double FProbability(double F, int df1, int df2) {
-
- return incompleteBeta( df2/2.0, df1/2.0, df2/(df2+df1*F) );
+
+ return incompleteBeta(df2 / 2.0, df1 / 2.0, df2 / (df2 + df1 * F));
}
/**
- * Returns the area under the Normal (Gaussian) probability density
- * function, integrated from minus infinity to <tt>x</tt>
- * (assumes mean is zero, variance is one).
+ * Returns the area under the Normal (Gaussian) probability density function,
+ * integrated from minus infinity to <tt>x</tt> (assumes mean is zero,
+ * variance is one).
+ *
* <pre>
* x
* -
@@ -165,176 +175,188 @@
* sqrt(2pi) | |
* -
* -inf.
- *
+ *
* = ( 1 + erf(z) ) / 2
* = erfc(z) / 2
* </pre>
- * where <tt>z = x/sqrt(2)</tt>.
- * Computation is via the functions <tt>errorFunction</tt> and <tt>errorFunctionComplement</tt>.
- *
- * @param a the z-value
+ *
+ * where <tt>z = x/sqrt(2)</tt>. Computation is via the functions
+ * <tt>errorFunction</tt> and <tt>errorFunctionComplement</tt>.
+ *
+ * @param a
+ * the z-value
* @return the probability of the z value according to the normal pdf
*/
- public static double normalProbability(double a) {
+ public static double normalProbability(double a) {
double x, y, z;
-
+
x = a * SQRTH;
z = Math.abs(x);
-
- if( z < SQRTH ) y = 0.5 + 0.5 * errorFunction(x);
+
+ if (z < SQRTH)
+ y = 0.5 + 0.5 * errorFunction(x);
else {
y = 0.5 * errorFunctionComplemented(z);
- if( x > 0 ) y = 1.0 - y;
- }
+ if (x > 0)
+ y = 1.0 - y;
+ }
return y;
}
/**
- * Returns the value, <tt>x</tt>, for which the area under the
- * Normal (Gaussian) probability density function (integrated from
- * minus infinity to <tt>x</tt>) is equal to the argument <tt>y</tt>
- * (assumes mean is zero, variance is one).
+ * Returns the value, <tt>x</tt>, for which the area under the Normal
+ * (Gaussian) probability density function (integrated from minus infinity to
+ * <tt>x</tt>) is equal to the argument <tt>y</tt> (assumes mean is zero,
+ * variance is one).
* <p>
* For small arguments <tt>0 < y < exp(-2)</tt>, the program computes
- * <tt>z = sqrt( -2.0 * log(y) )</tt>; then the approximation is
- * <tt>x = z - log(z)/z - (1/z) P(1/z) / Q(1/z)</tt>.
- * There are two rational functions P/Q, one for <tt>0 < y < exp(-32)</tt>
- * and the other for <tt>y</tt> up to <tt>exp(-2)</tt>.
- * For larger arguments,
- * <tt>w = y - 0.5</tt>, and <tt>x/sqrt(2pi) = w + w**3 R(w**2)/S(w**2))</tt>.
- *
- * @param y0 the area under the normal pdf
+ * <tt>z = sqrt( -2.0 * log(y) )</tt>; then the approximation is
+ * <tt>x = z - log(z)/z - (1/z) P(1/z) / Q(1/z)</tt>. There are two rational
+ * functions P/Q, one for <tt>0 < y < exp(-32)</tt> and the other for
+ * <tt>y</tt> up to <tt>exp(-2)</tt>. For larger arguments,
+ * <tt>w = y - 0.5</tt>, and <tt>x/sqrt(2pi) = w + w**3 R(w**2)/S(w**2))</tt>.
+ *
+ * @param y0
+ * the area under the normal pdf
* @return the z-value
*/
- public static double normalInverse(double y0) {
+ public static double normalInverse(double y0) {
double x, y, z, y2, x0, x1;
int code;
- final double s2pi = Math.sqrt(2.0*Math.PI);
+ final double s2pi = Math.sqrt(2.0 * Math.PI);
- if( y0 <= 0.0 ) throw new IllegalArgumentException();
- if( y0 >= 1.0 ) throw new IllegalArgumentException();
+ if (y0 <= 0.0)
+ throw new IllegalArgumentException();
+ if (y0 >= 1.0)
+ throw new IllegalArgumentException();
code = 1;
y = y0;
- if( y > (1.0 - 0.13533528323661269189) ) { /* 0.135... = exp(-2) */
+ if (y > (1.0 - 0.13533528323661269189)) { /* 0.135... = exp(-2) */
y = 1.0 - y;
code = 0;
}
- if( y > 0.13533528323661269189 ) {
+ if (y > 0.13533528323661269189) {
y = y - 0.5;
y2 = y * y;
- x = y + y * (y2 * polevl( y2, P0, 4)/p1evl( y2, Q0, 8 ));
- x = x * s2pi;
- return(x);
+ x = y + y * (y2 * polevl(y2, P0, 4) / p1evl(y2, Q0, 8));
+ x = x * s2pi;
+ return (x);
}
- x = Math.sqrt( -2.0 * Math.log(y) );
- x0 = x - Math.log(x)/x;
+ x = Math.sqrt(-2.0 * Math.log(y));
+ x0 = x - Math.log(x) / x;
- z = 1.0/x;
- if( x < 8.0 ) /* y > exp(-32) = 1.2664165549e-14 */
- x1 = z * polevl( z, P1, 8 )/p1evl( z, Q1, 8 );
+ z = 1.0 / x;
+ if (x < 8.0) /* y > exp(-32) = 1.2664165549e-14 */
+ x1 = z * polevl(z, P1, 8) / p1evl(z, Q1, 8);
else
- x1 = z * polevl( z, P2, 8 )/p1evl( z, Q2, 8 );
+ x1 = z * polevl(z, P2, 8) / p1evl(z, Q2, 8);
x = x0 - x1;
- if( code != 0 )
+ if (code != 0)
x = -x;
- return( x );
+ return (x);
}
/**
* Returns natural logarithm of gamma function.
- *
- * @param x the value
+ *
+ * @param x
+ * the value
* @return natural logarithm of gamma function
*/
public static double lnGamma(double x) {
double p, q, w, z;
-
+
double A[] = {
- 8.11614167470508450300E-4,
- -5.95061904284301438324E-4,
- 7.93650340457716943945E-4,
- -2.77777777730099687205E-3,
- 8.33333333333331927722E-2
+ 8.11614167470508450300E-4,
+ -5.95061904284301438324E-4,
+ 7.93650340457716943945E-4,
+ -2.77777777730099687205E-3,
+ 8.33333333333331927722E-2
};
double B[] = {
- -1.37825152569120859100E3,
- -3.88016315134637840924E4,
- -3.31612992738871184744E5,
- -1.16237097492762307383E6,
- -1.72173700820839662146E6,
- -8.53555664245765465627E5
+ -1.37825152569120859100E3,
+ -3.88016315134637840924E4,
+ -3.31612992738871184744E5,
+ -1.16237097492762307383E6,
+ -1.72173700820839662146E6,
+ -8.53555664245765465627E5
};
double C[] = {
- /* 1.00000000000000000000E0, */
- -3.51815701436523470549E2,
- -1.70642106651881159223E4,
- -2.20528590553854454839E5,
- -1.13933444367982507207E6,
- -2.53252307177582951285E6,
- -2.01889141433532773231E6
+ /* 1.00000000000000000000E0, */
+ -3.51815701436523470549E2,
+ -1.70642106651881159223E4,
+ -2.20528590553854454839E5,
+ -1.13933444367982507207E6,
+ -2.53252307177582951285E6,
+ -2.01889141433532773231E6
};
-
- if( x < -34.0 ) {
+
+ if (x < -34.0) {
q = -x;
w = lnGamma(q);
p = Math.floor(q);
- if( p == q ) throw new ArithmeticException("lnGamma: Overflow");
+ if (p == q)
+ throw new ArithmeticException("lnGamma: Overflow");
z = q - p;
- if( z > 0.5 ) {
- p += 1.0;
- z = p - q;
+ if (z > 0.5) {
+ p += 1.0;
+ z = p - q;
}
- z = q * Math.sin( Math.PI * z );
- if( z == 0.0 ) throw new
- ArithmeticException("lnGamma: Overflow");
- z = LOGPI - Math.log( z ) - w;
+ z = q * Math.sin(Math.PI * z);
+ if (z == 0.0)
+ throw new ArithmeticException("lnGamma: Overflow");
+ z = LOGPI - Math.log(z) - w;
return z;
}
-
- if( x < 13.0 ) {
+
+ if (x < 13.0) {
z = 1.0;
- while( x >= 3.0 ) {
- x -= 1.0;
- z *= x;
+ while (x >= 3.0) {
+ x -= 1.0;
+ z *= x;
}
- while( x < 2.0 ) {
- if( x == 0.0 ) throw new
- ArithmeticException("lnGamma: Overflow");
- z /= x;
- x += 1.0;
+ while (x < 2.0) {
+ if (x == 0.0)
+ throw new ArithmeticException("lnGamma: Overflow");
+ z /= x;
+ x += 1.0;
}
- if( z < 0.0 ) z = -z;
- if( x == 2.0 ) return Math.log(z);
+ if (z < 0.0)
+ z = -z;
+ if (x == 2.0)
+ return Math.log(z);
x -= 2.0;
- p = x * polevl( x, B, 5 ) / p1evl( x, C, 6);
- return( Math.log(z) + p );
+ p = x * polevl(x, B, 5) / p1evl(x, C, 6);
+ return (Math.log(z) + p);
}
-
- if( x > 2.556348e305 ) throw new ArithmeticException("lnGamma: Overflow");
-
- q = ( x - 0.5 ) * Math.log(x) - x + 0.91893853320467274178;
-
- if( x > 1.0e8 ) return( q );
-
- p = 1.0/(x*x);
- if( x >= 1000.0 )
- q += (( 7.9365079365079365079365e-4 * p
- - 2.7777777777777777777778e-3) *p
- + 0.0833333333333333333333) / x;
+
+ if (x > 2.556348e305)
+ throw new ArithmeticException("lnGamma: Overflow");
+
+ q = (x - 0.5) * Math.log(x) - x + 0.91893853320467274178;
+
+ if (x > 1.0e8)
+ return (q);
+
+ p = 1.0 / (x * x);
+ if (x >= 1000.0)
+ q += ((7.9365079365079365079365e-4 * p
+ - 2.7777777777777777777778e-3) * p
+ + 0.0833333333333333333333) / x;
else
- q += polevl( p, A, 4 ) / x;
+ q += polevl(p, A, 4) / x;
return q;
}
/**
- * Returns the error function of the normal distribution.
- * The integral is
+ * Returns the error function of the normal distribution. The integral is
+ *
* <pre>
* x
* -
@@ -344,47 +366,51 @@
* -
* 0
* </pre>
- * <b>Implementation:</b>
- * For <tt>0 <= |x| < 1, erf(x) = x * P4(x**2)/Q5(x**2)</tt>; otherwise
+ *
+ * <b>Implementation:</b> For
+ * <tt>0 <= |x| < 1, erf(x) = x * P4(x**2)/Q5(x**2)</tt>; otherwise
* <tt>erf(x) = 1 - erfc(x)</tt>.
* <p>
- * Code adapted from the <A HREF="http://www.sci.usq.edu.au/staff/leighb/graph/Top.html">
- * Java 2D Graph Package 2.4</A>,
- * which in turn is a port from the
- * <A HREF="http://people.ne.mediaone.net/moshier/index.html#Cephes">Cephes 2.2</A>
- * Math Library (C).
- *
- * @param a the argument to the function.
+ * Code adapted from the <A
+ * HREF="http://www.sci.usq.edu.au/staff/leighb/graph/Top.html"> Java 2D Graph
+ * Package 2.4</A>, which in turn is a port from the <A
+ * HREF="http://people.ne.mediaone.net/moshier/index.html#Cephes">Cephes
+ * 2.2</A> Math Library (C).
+ *
+   * @param x
+ * the argument to the function.
*/
- public static double errorFunction(double x) {
+ public static double errorFunction(double x) {
double y, z;
final double T[] = {
- 9.60497373987051638749E0,
- 9.00260197203842689217E1,
- 2.23200534594684319226E3,
- 7.00332514112805075473E3,
- 5.55923013010394962768E4
+ 9.60497373987051638749E0,
+ 9.00260197203842689217E1,
+ 2.23200534594684319226E3,
+ 7.00332514112805075473E3,
+ 5.55923013010394962768E4
};
final double U[] = {
- //1.00000000000000000000E0,
- 3.35617141647503099647E1,
- 5.21357949780152679795E2,
- 4.59432382970980127987E3,
- 2.26290000613890934246E4,
- 4.92673942608635921086E4
+ // 1.00000000000000000000E0,
+ 3.35617141647503099647E1,
+ 5.21357949780152679795E2,
+ 4.59432382970980127987E3,
+ 2.26290000613890934246E4,
+ 4.92673942608635921086E4
};
-
- if( Math.abs(x) > 1.0 ) return( 1.0 - errorFunctionComplemented(x) );
+
+ if (Math.abs(x) > 1.0)
+ return (1.0 - errorFunctionComplemented(x));
z = x * x;
- y = x * polevl( z, T, 4 ) / p1evl( z, U, 5 );
+ y = x * polevl(z, T, 4) / p1evl(z, U, 5);
return y;
}
/**
* Returns the complementary Error function of the normal distribution.
+ *
* <pre>
* 1 - erf(x) =
- *
+ *
* inf.
* -
* 2 | | 2
@@ -393,174 +419,201 @@
* -
* x
* </pre>
- * <b>Implementation:</b>
- * For small x, <tt>erfc(x) = 1 - erf(x)</tt>; otherwise rational
- * approximations are computed.
+ *
+ * <b>Implementation:</b> For small x, <tt>erfc(x) = 1 - erf(x)</tt>;
+ * otherwise rational approximations are computed.
* <p>
- * Code adapted from the <A HREF="http://www.sci.usq.edu.au/staff/leighb/graph/Top.html">
- * Java 2D Graph Package 2.4</A>,
- * which in turn is a port from the
- * <A HREF="http://people.ne.mediaone.net/moshier/index.html#Cephes">Cephes 2.2</A>
- * Math Library (C).
- *
- * @param a the argument to the function.
+ * Code adapted from the <A
+ * HREF="http://www.sci.usq.edu.au/staff/leighb/graph/Top.html"> Java 2D Graph
+ * Package 2.4</A>, which in turn is a port from the <A
+ * HREF="http://people.ne.mediaone.net/moshier/index.html#Cephes">Cephes
+ * 2.2</A> Math Library (C).
+ *
+ * @param a
+ * the argument to the function.
*/
- public static double errorFunctionComplemented(double a) {
- double x,y,z,p,q;
-
+ public static double errorFunctionComplemented(double a) {
+ double x, y, z, p, q;
+
double P[] = {
- 2.46196981473530512524E-10,
- 5.64189564831068821977E-1,
- 7.46321056442269912687E0,
- 4.86371970985681366614E1,
- 1.96520832956077098242E2,
- 5.26445194995477358631E2,
- 9.34528527171957607540E2,
- 1.02755188689515710272E3,
- 5.57535335369399327526E2
+ 2.46196981473530512524E-10,
+ 5.64189564831068821977E-1,
+ 7.46321056442269912687E0,
+ 4.86371970985681366614E1,
+ 1.96520832956077098242E2,
+ 5.26445194995477358631E2,
+ 9.34528527171957607540E2,
+ 1.02755188689515710272E3,
+ 5.57535335369399327526E2
};
double Q[] = {
- //1.0
- 1.32281951154744992508E1,
- 8.67072140885989742329E1,
- 3.54937778887819891062E2,
- 9.75708501743205489753E2,
- 1.82390916687909736289E3,
- 2.24633760818710981792E3,
- 1.65666309194161350182E3,
- 5.57535340817727675546E2
+ // 1.0
+ 1.32281951154744992508E1,
+ 8.67072140885989742329E1,
+ 3.54937778887819891062E2,
+ 9.75708501743205489753E2,
+ 1.82390916687909736289E3,
+ 2.24633760818710981792E3,
+ 1.65666309194161350182E3,
+ 5.57535340817727675546E2
};
-
+
double R[] = {
- 5.64189583547755073984E-1,
- 1.27536670759978104416E0,
- 5.01905042251180477414E0,
- 6.16021097993053585195E0,
- 7.40974269950448939160E0,
- 2.97886665372100240670E0
+ 5.64189583547755073984E-1,
+ 1.27536670759978104416E0,
+ 5.01905042251180477414E0,
+ 6.16021097993053585195E0,
+ 7.40974269950448939160E0,
+ 2.97886665372100240670E0
};
double S[] = {
- //1.00000000000000000000E0,
- 2.26052863220117276590E0,
- 9.39603524938001434673E0,
- 1.20489539808096656605E1,
- 1.70814450747565897222E1,
- 9.60896809063285878198E0,
- 3.36907645100081516050E0
+ // 1.00000000000000000000E0,
+ 2.26052863220117276590E0,
+ 9.39603524938001434673E0,
+ 1.20489539808096656605E1,
+ 1.70814450747565897222E1,
+ 9.60896809063285878198E0,
+ 3.36907645100081516050E0
};
-
- if( a < 0.0 ) x = -a;
- else x = a;
-
- if( x < 1.0 ) return 1.0 - errorFunction(a);
-
+
+ if (a < 0.0)
+ x = -a;
+ else
+ x = a;
+
+ if (x < 1.0)
+ return 1.0 - errorFunction(a);
+
z = -a * a;
-
- if( z < -MAXLOG ) {
- if( a < 0 ) return( 2.0 );
- else return( 0.0 );
+
+ if (z < -MAXLOG) {
+ if (a < 0)
+ return (2.0);
+ else
+ return (0.0);
}
-
+
z = Math.exp(z);
-
- if( x < 8.0 ) {
- p = polevl( x, P, 8 );
- q = p1evl( x, Q, 8 );
+
+ if (x < 8.0) {
+ p = polevl(x, P, 8);
+ q = p1evl(x, Q, 8);
} else {
- p = polevl( x, R, 5 );
- q = p1evl( x, S, 6 );
+ p = polevl(x, R, 5);
+ q = p1evl(x, S, 6);
}
-
- y = (z * p)/q;
-
- if( a < 0 ) y = 2.0 - y;
-
- if( y == 0.0 ) {
- if( a < 0 ) return 2.0;
- else return( 0.0 );
+
+ y = (z * p) / q;
+
+ if (a < 0)
+ y = 2.0 - y;
+
+ if (y == 0.0) {
+ if (a < 0)
+ return 2.0;
+ else
+ return (0.0);
}
return y;
}
-
+
/**
* Evaluates the given polynomial of degree <tt>N</tt> at <tt>x</tt>.
- * Evaluates polynomial when coefficient of N is 1.0.
- * Otherwise same as <tt>polevl()</tt>.
+ * Evaluates polynomial when coefficient of N is 1.0. Otherwise same as
+ * <tt>polevl()</tt>.
+ *
* <pre>
* 2 N
* y = C + C x + C x +...+ C x
* 0 1 2 N
- *
+ *
* Coefficients are stored in reverse order:
- *
+ *
* coef[0] = C , ..., coef[N] = C .
* N 0
* </pre>
+ *
* The function <tt>p1evl()</tt> assumes that <tt>coef[N] = 1.0</tt> and is
- * omitted from the array. Its calling arguments are
- * otherwise the same as <tt>polevl()</tt>.
+ * omitted from the array. Its calling arguments are otherwise the same as
+ * <tt>polevl()</tt>.
* <p>
* In the interest of speed, there are no checks for out of bounds arithmetic.
- *
- * @param x argument to the polynomial.
- * @param coef the coefficients of the polynomial.
- * @param N the degree of the polynomial.
+ *
+ * @param x
+ * argument to the polynomial.
+ * @param coef
+ * the coefficients of the polynomial.
+ * @param N
+ * the degree of the polynomial.
*/
- public static double p1evl( double x, double coef[], int N ) {
-
+ public static double p1evl(double x, double coef[], int N) {
+
double ans;
ans = x + coef[0];
-
- for(int i=1; i<N; i++) ans = ans*x+coef[i];
-
+
+ for (int i = 1; i < N; i++)
+ ans = ans * x + coef[i];
+
return ans;
}
/**
* Evaluates the given polynomial of degree <tt>N</tt> at <tt>x</tt>.
+ *
* <pre>
* 2 N
* y = C + C x + C x +...+ C x
* 0 1 2 N
- *
+ *
* Coefficients are stored in reverse order:
- *
+ *
* coef[0] = C , ..., coef[N] = C .
* N 0
* </pre>
+ *
* In the interest of speed, there are no checks for out of bounds arithmetic.
- *
- * @param x argument to the polynomial.
- * @param coef the coefficients of the polynomial.
- * @param N the degree of the polynomial.
+ *
+ * @param x
+ * argument to the polynomial.
+ * @param coef
+ * the coefficients of the polynomial.
+ * @param N
+ * the degree of the polynomial.
*/
- public static double polevl( double x, double coef[], int N ) {
+ public static double polevl(double x, double coef[], int N) {
double ans;
ans = coef[0];
-
- for(int i=1; i<=N; i++) ans = ans*x+coef[i];
-
+
+ for (int i = 1; i <= N; i++)
+ ans = ans * x + coef[i];
+
return ans;
}
/**
* Returns the Incomplete Gamma function.
- * @param a the parameter of the gamma distribution.
- * @param x the integration end point.
+ *
+ * @param a
+ * the parameter of the gamma distribution.
+ * @param x
+ * the integration end point.
*/
- public static double incompleteGamma(double a, double x)
- {
-
- double ans, ax, c, r;
-
- if( x <= 0 || a <= 0 ) return 0.0;
-
- if( x > 1.0 && x > a ) return 1.0 - incompleteGammaComplement(a,x);
+ public static double incompleteGamma(double a, double x)
+ {
- /* Compute x**a * exp(-x) / gamma(a) */
+ double ans, ax, c, r;
+
+ if (x <= 0 || a <= 0)
+ return 0.0;
+
+ if (x > 1.0 && x > a)
+ return 1.0 - incompleteGammaComplement(a, x);
+
+ /* Compute x**a * exp(-x) / gamma(a) */
ax = a * Math.log(x) - x - lnGamma(a);
- if( ax < -MAXLOG ) return( 0.0 );
+ if (ax < -MAXLOG)
+ return (0.0);
ax = Math.exp(ax);
@@ -571,33 +624,38 @@
do {
r += 1.0;
- c *= x/r;
+ c *= x / r;
ans += c;
- }
- while( c/ans > MACHEP );
-
- return( ans * ax/a );
+ } while (c / ans > MACHEP);
+
+ return (ans * ax / a);
}
/**
* Returns the Complemented Incomplete Gamma function.
- * @param a the parameter of the gamma distribution.
- * @param x the integration start point.
+ *
+ * @param a
+ * the parameter of the gamma distribution.
+ * @param x
+ * the integration start point.
*/
- public static double incompleteGammaComplement( double a, double x ) {
+ public static double incompleteGammaComplement(double a, double x) {
double ans, ax, c, yc, r, t, y, z;
double pk, pkm1, pkm2, qk, qkm1, qkm2;
- if( x <= 0 || a <= 0 ) return 1.0;
-
- if( x < 1.0 || x < a ) return 1.0 - incompleteGamma(a,x);
-
+ if (x <= 0 || a <= 0)
+ return 1.0;
+
+ if (x < 1.0 || x < a)
+ return 1.0 - incompleteGamma(a, x);
+
ax = a * Math.log(x) - x - lnGamma(a);
- if( ax < -MAXLOG ) return 0.0;
-
+ if (ax < -MAXLOG)
+ return 0.0;
+
ax = Math.exp(ax);
-
+
/* continued fraction */
y = 1.0 - a;
z = x + y + 1.0;
@@ -606,34 +664,34 @@
qkm2 = x;
pkm1 = x + 1.0;
qkm1 = z * x;
- ans = pkm1/qkm1;
-
+ ans = pkm1 / qkm1;
+
do {
c += 1.0;
y += 1.0;
z += 2.0;
yc = y * c;
- pk = pkm1 * z - pkm2 * yc;
- qk = qkm1 * z - qkm2 * yc;
- if( qk != 0 ) {
- r = pk/qk;
- t = Math.abs( (ans - r)/r );
- ans = r;
+ pk = pkm1 * z - pkm2 * yc;
+ qk = qkm1 * z - qkm2 * yc;
+ if (qk != 0) {
+ r = pk / qk;
+ t = Math.abs((ans - r) / r);
+ ans = r;
} else
- t = 1.0;
+ t = 1.0;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
- if( Math.abs(pk) > big ) {
- pkm2 *= biginv;
+ if (Math.abs(pk) > big) {
+ pkm2 *= biginv;
pkm1 *= biginv;
- qkm2 *= biginv;
- qkm1 *= biginv;
+ qkm2 *= biginv;
+ qkm1 *= biginv;
}
- } while( t > MACHEP );
-
+ } while (t > MACHEP);
+
return ans * ax;
}
@@ -643,23 +701,23 @@
public static double gamma(double x) {
double P[] = {
- 1.60119522476751861407E-4,
- 1.19135147006586384913E-3,
- 1.04213797561761569935E-2,
- 4.76367800457137231464E-2,
- 2.07448227648435975150E-1,
- 4.94214826801497100753E-1,
- 9.99999999999999996796E-1
+ 1.60119522476751861407E-4,
+ 1.19135147006586384913E-3,
+ 1.04213797561761569935E-2,
+ 4.76367800457137231464E-2,
+ 2.07448227648435975150E-1,
+ 4.94214826801497100753E-1,
+ 9.99999999999999996796E-1
};
double Q[] = {
- -2.31581873324120129819E-5,
- 5.39605580493303397842E-4,
- -4.45641913851797240494E-3,
- 1.18139785222060435552E-2,
- 3.58236398605498653373E-2,
- -2.34591795718243348568E-1,
- 7.14304917030273074085E-2,
- 1.00000000000000000320E0
+ -2.31581873324120129819E-5,
+ 5.39605580493303397842E-4,
+ -4.45641913851797240494E-3,
+ 1.18139785222060435552E-2,
+ 3.58236398605498653373E-2,
+ -2.34591795718243348568E-1,
+ 7.14304917030273074085E-2,
+ 1.00000000000000000320E0
};
double p, z;
@@ -667,89 +725,90 @@
double q = Math.abs(x);
- if( q > 33.0 ) {
- if( x < 0.0 ) {
- p = Math.floor(q);
- if( p == q ) throw new ArithmeticException("gamma: overflow");
- i = (int)p;
- z = q - p;
- if( z > 0.5 ) {
- p += 1.0;
- z = q - p;
- }
- z = q * Math.sin( Math.PI * z );
- if( z == 0.0 ) throw new ArithmeticException("gamma: overflow");
- z = Math.abs(z);
- z = Math.PI/(z * stirlingFormula(q) );
+ if (q > 33.0) {
+ if (x < 0.0) {
+ p = Math.floor(q);
+ if (p == q)
+ throw new ArithmeticException("gamma: overflow");
+ i = (int) p;
+ z = q - p;
+ if (z > 0.5) {
+ p += 1.0;
+ z = q - p;
+ }
+ z = q * Math.sin(Math.PI * z);
+ if (z == 0.0)
+ throw new ArithmeticException("gamma: overflow");
+ z = Math.abs(z);
+ z = Math.PI / (z * stirlingFormula(q));
- return -z;
+ return -z;
} else {
- return stirlingFormula(x);
+ return stirlingFormula(x);
}
}
z = 1.0;
- while( x >= 3.0 ) {
+ while (x >= 3.0) {
x -= 1.0;
z *= x;
}
- while( x < 0.0 ) {
- if( x == 0.0 ) {
- throw new ArithmeticException("gamma: singular");
- } else
- if( x > -1.E-9 ) {
- return( z/((1.0 + 0.5772156649015329 * x) * x) );
- }
+ while (x < 0.0) {
+ if (x == 0.0) {
+ throw new ArithmeticException("gamma: singular");
+ } else if (x > -1.E-9) {
+ return (z / ((1.0 + 0.5772156649015329 * x) * x));
+ }
z /= x;
x += 1.0;
}
- while( x < 2.0 ) {
- if( x == 0.0 ) {
- throw new ArithmeticException("gamma: singular");
- } else
- if( x < 1.e-9 ) {
- return( z/((1.0 + 0.5772156649015329 * x) * x) );
- }
+ while (x < 2.0) {
+ if (x == 0.0) {
+ throw new ArithmeticException("gamma: singular");
+ } else if (x < 1.e-9) {
+ return (z / ((1.0 + 0.5772156649015329 * x) * x));
+ }
z /= x;
x += 1.0;
}
- if( (x == 2.0) || (x == 3.0) ) return z;
+ if ((x == 2.0) || (x == 3.0))
+ return z;
x -= 2.0;
- p = polevl( x, P, 6 );
- q = polevl( x, Q, 7 );
- return z * p / q;
+ p = polevl(x, P, 6);
+ q = polevl(x, Q, 7);
+ return z * p / q;
}
/**
- * Returns the Gamma function computed by Stirling's formula.
- * The polynomial STIR is valid for 33 <= x <= 172.
+ * Returns the Gamma function computed by Stirling's formula. The polynomial
+ * STIR is valid for 33 <= x <= 172.
*/
public static double stirlingFormula(double x) {
double STIR[] = {
- 7.87311395793093628397E-4,
- -2.29549961613378126380E-4,
- -2.68132617805781232825E-3,
- 3.47222221605458667310E-3,
- 8.33333333333482257126E-2,
+ 7.87311395793093628397E-4,
+ -2.29549961613378126380E-4,
+ -2.68132617805781232825E-3,
+ 3.47222221605458667310E-3,
+ 8.33333333333482257126E-2,
};
double MAXSTIR = 143.01608;
- double w = 1.0/x;
- double y = Math.exp(x);
+ double w = 1.0 / x;
+ double y = Math.exp(x);
- w = 1.0 + w * polevl( w, STIR, 4 );
+ w = 1.0 + w * polevl(w, STIR, 4);
- if( x > MAXSTIR ) {
+ if (x > MAXSTIR) {
/* Avoid overflow in Math.pow() */
- double v = Math.pow( x, 0.5 * x - 0.25 );
+ double v = Math.pow(x, 0.5 * x - 0.25);
y = v * (v / y);
} else {
- y = Math.pow( x, x - 0.5 ) / y;
+ y = Math.pow(x, x - 0.5) / y;
}
y = SQTPI * y * w;
return y;
@@ -757,27 +816,32 @@
/**
* Returns the Incomplete Beta Function evaluated from zero to <tt>xx</tt>.
- *
- * @param aa the alpha parameter of the beta distribution.
- * @param bb the beta parameter of the beta distribution.
- * @param xx the integration end point.
+ *
+ * @param aa
+ * the alpha parameter of the beta distribution.
+ * @param bb
+ * the beta parameter of the beta distribution.
+ * @param xx
+ * the integration end point.
*/
- public static double incompleteBeta( double aa, double bb, double xx ) {
+ public static double incompleteBeta(double aa, double bb, double xx) {
double a, b, t, x, xc, w, y;
boolean flag;
- if( aa <= 0.0 || bb <= 0.0 ) throw new
- ArithmeticException("ibeta: Domain error!");
+ if (aa <= 0.0 || bb <= 0.0)
+ throw new ArithmeticException("ibeta: Domain error!");
- if( (xx <= 0.0) || ( xx >= 1.0) ) {
- if( xx == 0.0 ) return 0.0;
- if( xx == 1.0 ) return 1.0;
+ if ((xx <= 0.0) || (xx >= 1.0)) {
+ if (xx == 0.0)
+ return 0.0;
+ if (xx == 1.0)
+ return 1.0;
throw new ArithmeticException("ibeta: Domain error!");
}
flag = false;
- if( (bb * xx) <= 1.0 && xx <= 0.95) {
+ if ((bb * xx) <= 1.0 && xx <= 0.95) {
t = powerSeries(aa, bb, xx);
return t;
}
@@ -785,7 +849,7 @@
w = 1.0 - xx;
/* Reverse a and b if x is greater than the mean. */
- if( xx > (aa/(aa+bb)) ) {
+ if (xx > (aa / (aa + bb))) {
flag = true;
a = bb;
b = aa;
@@ -798,57 +862,63 @@
x = xx;
}
- if( flag && (b * x) <= 1.0 && x <= 0.95) {
+ if (flag && (b * x) <= 1.0 && x <= 0.95) {
t = powerSeries(a, b, x);
- if( t <= MACHEP ) t = 1.0 - MACHEP;
- else t = 1.0 - t;
+ if (t <= MACHEP)
+ t = 1.0 - MACHEP;
+ else
+ t = 1.0 - t;
return t;
}
/* Choose expansion for better convergence. */
- y = x * (a+b-2.0) - (a-1.0);
- if( y < 0.0 )
- w = incompleteBetaFraction1( a, b, x );
+ y = x * (a + b - 2.0) - (a - 1.0);
+ if (y < 0.0)
+ w = incompleteBetaFraction1(a, b, x);
else
- w = incompleteBetaFraction2( a, b, x ) / xc;
+ w = incompleteBetaFraction2(a, b, x) / xc;
- /* Multiply w by the factor
- a b _ _ _
- x (1-x) | (a+b) / ( a | (a) | (b) ) . */
+ /*
+ * Multiply w by the factor a b _ _ _ x (1-x) | (a+b) / ( a | (a) | (b) ) .
+ */
y = a * Math.log(x);
t = b * Math.log(xc);
- if( (a+b) < MAXGAM && Math.abs(y) < MAXLOG && Math.abs(t) < MAXLOG ) {
- t = Math.pow(xc,b);
- t *= Math.pow(x,a);
+ if ((a + b) < MAXGAM && Math.abs(y) < MAXLOG && Math.abs(t) < MAXLOG) {
+ t = Math.pow(xc, b);
+ t *= Math.pow(x, a);
t /= a;
t *= w;
- t *= gamma(a+b) / (gamma(a) * gamma(b));
- if( flag ) {
- if( t <= MACHEP ) t = 1.0 - MACHEP;
- else t = 1.0 - t;
+ t *= gamma(a + b) / (gamma(a) * gamma(b));
+ if (flag) {
+ if (t <= MACHEP)
+ t = 1.0 - MACHEP;
+ else
+ t = 1.0 - t;
}
return t;
}
- /* Resort to logarithms. */
- y += t + lnGamma(a+b) - lnGamma(a) - lnGamma(b);
- y += Math.log(w/a);
- if( y < MINLOG )
+ /* Resort to logarithms. */
+ y += t + lnGamma(a + b) - lnGamma(a) - lnGamma(b);
+ y += Math.log(w / a);
+ if (y < MINLOG)
t = 0.0;
else
t = Math.exp(y);
- if( flag ) {
- if( t <= MACHEP ) t = 1.0 - MACHEP;
- else t = 1.0 - t;
+ if (flag) {
+ if (t <= MACHEP)
+ t = 1.0 - MACHEP;
+ else
+ t = 1.0 - t;
}
return t;
- }
+ }
/**
* Continued fraction expansion #1 for incomplete beta integral.
*/
- public static double incompleteBetaFraction1( double a, double b, double x ) {
+ public static double incompleteBetaFraction1(double a, double b, double x) {
double xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
double k1, k2, k3, k4, k5, k6, k7, k8;
@@ -873,30 +943,32 @@
n = 0;
thresh = 3.0 * MACHEP;
do {
- xk = -( x * k1 * k2 )/( k3 * k4 );
- pk = pkm1 + pkm2 * xk;
- qk = qkm1 + qkm2 * xk;
+ xk = -(x * k1 * k2) / (k3 * k4);
+ pk = pkm1 + pkm2 * xk;
+ qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
- xk = ( x * k5 * k6 )/( k7 * k8 );
- pk = pkm1 + pkm2 * xk;
- qk = qkm1 + qkm2 * xk;
+ xk = (x * k5 * k6) / (k7 * k8);
+ pk = pkm1 + pkm2 * xk;
+ qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
- if( qk != 0 ) r = pk/qk;
- if( r != 0 ) {
- t = Math.abs( (ans - r)/r );
- ans = r;
- } else
- t = 1.0;
+ if (qk != 0)
+ r = pk / qk;
+ if (r != 0) {
+ t = Math.abs((ans - r) / r);
+ ans = r;
+ } else
+ t = 1.0;
- if( t < thresh ) return ans;
+ if (t < thresh)
+ return ans;
k1 += 1.0;
k2 += 1.0;
@@ -907,27 +979,27 @@
k7 += 2.0;
k8 += 2.0;
- if( (Math.abs(qk) + Math.abs(pk)) > big ) {
- pkm2 *= biginv;
- pkm1 *= biginv;
- qkm2 *= biginv;
- qkm1 *= biginv;
+ if ((Math.abs(qk) + Math.abs(pk)) > big) {
+ pkm2 *= biginv;
+ pkm1 *= biginv;
+ qkm2 *= biginv;
+ qkm1 *= biginv;
}
- if( (Math.abs(qk) < biginv) || (Math.abs(pk) < biginv) ) {
- pkm2 *= big;
- pkm1 *= big;
- qkm2 *= big;
- qkm1 *= big;
+ if ((Math.abs(qk) < biginv) || (Math.abs(pk) < biginv)) {
+ pkm2 *= big;
+ pkm1 *= big;
+ qkm2 *= big;
+ qkm1 *= big;
}
- } while( ++n < 300 );
+ } while (++n < 300);
return ans;
- }
+ }
/**
* Continued fraction expansion #2 for incomplete beta integral.
*/
- public static double incompleteBetaFraction2( double a, double b, double x ) {
+ public static double incompleteBetaFraction2(double a, double b, double x) {
double xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
double k1, k2, k3, k4, k5, k6, k7, k8;
@@ -940,43 +1012,45 @@
k4 = a + 1.0;
k5 = 1.0;
k6 = a + b;
- k7 = a + 1.0;;
+ k7 = a + 1.0;
k8 = a + 2.0;
pkm2 = 0.0;
qkm2 = 1.0;
pkm1 = 1.0;
qkm1 = 1.0;
- z = x / (1.0-x);
+ z = x / (1.0 - x);
ans = 1.0;
r = 1.0;
n = 0;
thresh = 3.0 * MACHEP;
do {
- xk = -( z * k1 * k2 )/( k3 * k4 );
- pk = pkm1 + pkm2 * xk;
- qk = qkm1 + qkm2 * xk;
+ xk = -(z * k1 * k2) / (k3 * k4);
+ pk = pkm1 + pkm2 * xk;
+ qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
- xk = ( z * k5 * k6 )/( k7 * k8 );
- pk = pkm1 + pkm2 * xk;
- qk = qkm1 + qkm2 * xk;
+ xk = (z * k5 * k6) / (k7 * k8);
+ pk = pkm1 + pkm2 * xk;
+ qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;
- if( qk != 0 ) r = pk/qk;
- if( r != 0 ) {
- t = Math.abs( (ans - r)/r );
- ans = r;
+ if (qk != 0)
+ r = pk / qk;
+ if (r != 0) {
+ t = Math.abs((ans - r) / r);
+ ans = r;
} else
- t = 1.0;
+ t = 1.0;
- if( t < thresh ) return ans;
+ if (t < thresh)
+ return ans;
k1 += 1.0;
k2 -= 1.0;
@@ -987,31 +1062,31 @@
k7 += 2.0;
k8 += 2.0;
- if( (Math.abs(qk) + Math.abs(pk)) > big ) {
- pkm2 *= biginv;
- pkm1 *= biginv;
- qkm2 *= biginv;
- qkm1 *= biginv;
+ if ((Math.abs(qk) + Math.abs(pk)) > big) {
+ pkm2 *= biginv;
+ pkm1 *= biginv;
+ qkm2 *= biginv;
+ qkm1 *= biginv;
}
- if( (Math.abs(qk) < biginv) || (Math.abs(pk) < biginv) ) {
- pkm2 *= big;
- pkm1 *= big;
- qkm2 *= big;
- qkm1 *= big;
+ if ((Math.abs(qk) < biginv) || (Math.abs(pk) < biginv)) {
+ pkm2 *= big;
+ pkm1 *= big;
+ qkm2 *= big;
+ qkm1 *= big;
}
- } while( ++n < 300 );
+ } while (++n < 300);
return ans;
}
/**
- * Power series for incomplete beta integral.
- * Use when b*x is small and x not too close to 1.
+ * Power series for incomplete beta integral. Use when b*x is small and x not
+ * too close to 1.
*/
- public static double powerSeries( double a, double b, double x ) {
+ public static double powerSeries(double a, double b, double x) {
double s, t, u, v, n, t1, z, ai;
-
+
ai = 1.0 / a;
u = (1.0 - b) * x;
v = u / (a + 1.0);
@@ -1020,27 +1095,28 @@
n = 2.0;
s = 0.0;
z = MACHEP * ai;
- while( Math.abs(v) > z ) {
+ while (Math.abs(v) > z) {
u = (n - b) * x / n;
t *= u;
v = t / (a + n);
- s += v;
+ s += v;
n += 1.0;
}
s += t1;
s += ai;
u = a * Math.log(x);
- if( (a+b) < MAXGAM && Math.abs(u) < MAXLOG ) {
- t = gamma(a+b)/(gamma(a)*gamma(b));
- s = s * t * Math.pow(x,a);
+ if ((a + b) < MAXGAM && Math.abs(u) < MAXLOG) {
+ t = gamma(a + b) / (gamma(a) * gamma(b));
+ s = s * t * Math.pow(x, a);
} else {
- t = lnGamma(a+b) - lnGamma(a) - lnGamma(b) + u + Math.log(s);
- if( t < MINLOG ) s = 0.0;
- else s = Math.exp(t);
+ t = lnGamma(a + b) - lnGamma(a) - lnGamma(b) + u + Math.log(s);
+ if (t < MINLOG)
+ s = 0.0;
+ else
+ s = Math.exp(t);
}
return s;
}
-
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/StringUtils.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/StringUtils.java
index 6a27206..a14e21e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/StringUtils.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/StringUtils.java
@@ -24,73 +24,73 @@
/**
* Class implementing some string utility methods.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class StringUtils {
- public static final String newline = System.getProperty("line.separator");
+ public static final String newline = System.getProperty("line.separator");
- public static String doubleToString(double value, int fractionDigits) {
- return doubleToString(value, 0, fractionDigits);
- }
+ public static String doubleToString(double value, int fractionDigits) {
+ return doubleToString(value, 0, fractionDigits);
+ }
- public static String doubleToString(double value, int minFractionDigits,
- int maxFractionDigits) {
- DecimalFormat numberFormat = new DecimalFormat();
- numberFormat.setMinimumFractionDigits(minFractionDigits);
- numberFormat.setMaximumFractionDigits(maxFractionDigits);
- return numberFormat.format(value);
- }
+ public static String doubleToString(double value, int minFractionDigits,
+ int maxFractionDigits) {
+ DecimalFormat numberFormat = new DecimalFormat();
+ numberFormat.setMinimumFractionDigits(minFractionDigits);
+ numberFormat.setMaximumFractionDigits(maxFractionDigits);
+ return numberFormat.format(value);
+ }
- public static void appendNewline(StringBuilder out) {
- out.append(newline);
- }
+ public static void appendNewline(StringBuilder out) {
+ out.append(newline);
+ }
- public static void appendIndent(StringBuilder out, int indent) {
- for (int i = 0; i < indent; i++) {
- out.append(' ');
- }
+ public static void appendIndent(StringBuilder out, int indent) {
+ for (int i = 0; i < indent; i++) {
+ out.append(' ');
}
+ }
- public static void appendIndented(StringBuilder out, int indent, String s) {
- appendIndent(out, indent);
- out.append(s);
- }
+ public static void appendIndented(StringBuilder out, int indent, String s) {
+ appendIndent(out, indent);
+ out.append(s);
+ }
- public static void appendNewlineIndented(StringBuilder out, int indent,
- String s) {
- appendNewline(out);
- appendIndented(out, indent, s);
- }
+ public static void appendNewlineIndented(StringBuilder out, int indent,
+ String s) {
+ appendNewline(out);
+ appendIndented(out, indent, s);
+ }
- public static String secondsToDHMSString(double seconds) {
- if (seconds < 60) {
- return doubleToString(seconds, 2, 2) + 's';
- }
- long secs = (int) (seconds);
- long mins = secs / 60;
- long hours = mins / 60;
- long days = hours / 24;
- secs %= 60;
- mins %= 60;
- hours %= 24;
- StringBuilder result = new StringBuilder();
- if (days > 0) {
- result.append(days);
- result.append('d');
- }
- if ((hours > 0) || (days > 0)) {
- result.append(hours);
- result.append('h');
- }
- if ((hours > 0) || (days > 0) || (mins > 0)) {
- result.append(mins);
- result.append('m');
- }
- result.append(secs);
- result.append('s');
- return result.toString();
+ public static String secondsToDHMSString(double seconds) {
+ if (seconds < 60) {
+ return doubleToString(seconds, 2, 2) + 's';
}
+ long secs = (int) (seconds);
+ long mins = secs / 60;
+ long hours = mins / 60;
+ long days = hours / 24;
+ secs %= 60;
+ mins %= 60;
+ hours %= 24;
+ StringBuilder result = new StringBuilder();
+ if (days > 0) {
+ result.append(days);
+ result.append('d');
+ }
+ if ((hours > 0) || (days > 0)) {
+ result.append(hours);
+ result.append('h');
+ }
+ if ((hours > 0) || (days > 0) || (mins > 0)) {
+ result.append(mins);
+ result.append('m');
+ }
+ result.append(secs);
+ result.append('s');
+ return result.toString();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Utils.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Utils.java
index f627cb8..f406b2c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Utils.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/core/Utils.java
@@ -1,4 +1,3 @@
-
/*
* Utils.java
@@ -43,10 +42,10 @@
/**
* Class implementing some simple utility methods.
- *
- * @author Eibe Frank
- * @author Yong Wang
- * @author Len Trigg
+ *
+ * @author Eibe Frank
+ * @author Yong Wang
+ * @author Len Trigg
* @author Julien Prados
* @version $Revision: 8080 $
*/
@@ -57,11 +56,12 @@
/** The small deviation allowed in double comparisons. */
public static double SMALL = 1e-6;
-
+
/**
* Tests if the given value codes "missing".
- *
- * @param val the value to be tested
+ *
+ * @param val
+ * the value to be tested
* @return true if val codes "missing"
*/
public static boolean isMissingValue(double val) {
@@ -70,41 +70,42 @@
}
/**
- * Returns the value used to code a missing value. Note that
- * equality tests on this value will always return false, so use
- * isMissingValue(double val) for testing..
- *
+ * Returns the value used to code a missing value. Note that equality tests on
+ * this value will always return false, so use isMissingValue(double val) for
+ * testing..
+ *
* @return the value used as missing value.
*/
public static double missingValue() {
-
+
return Double.NaN;
}
/**
- * Casting an object without "unchecked" compile-time warnings.
- * Use only when absolutely necessary (e.g. when using clone()).
+ * Casting an object without "unchecked" compile-time warnings. Use only when
+ * absolutely necessary (e.g. when using clone()).
*/
@SuppressWarnings("unchecked")
- public static <T> T cast(Object x) {
+ public static <T> T cast(Object x) {
return (T) x;
}
-
-
/**
* Returns the correlation coefficient of two double vectors.
- *
- * @param y1 double vector 1
- * @param y2 double vector 2
- * @param n the length of two double vectors
+ *
+ * @param y1
+ * double vector 1
+ * @param y2
+ * double vector 2
+ * @param n
+ * the length of two double vectors
* @return the correlation coefficient
*/
- public static final double correlation(double y1[],double y2[],int n) {
+ public static final double correlation(double y1[], double y2[], int n) {
int i;
double av1 = 0.0, av2 = 0.0, y11 = 0.0, y22 = 0.0, y12 = 0.0, c;
-
+
if (n <= 1) {
return 1.0;
}
@@ -120,26 +121,28 @@
y12 += (y1[i] - av1) * (y2[i] - av2);
}
if (y11 * y22 == 0.0) {
- c=1.0;
+ c = 1.0;
} else {
c = y12 / Math.sqrt(Math.abs(y11 * y22));
}
-
+
return c;
}
/**
* Removes all occurrences of a string from another string.
- *
- * @param inString the string to remove substrings from.
- * @param substring the substring to remove.
+ *
+ * @param inString
+ * the string to remove substrings from.
+ * @param substring
+ * the substring to remove.
* @return the input string with occurrences of substring removed.
*/
public static String removeSubstring(String inString, String substring) {
StringBuffer result = new StringBuffer();
int oldLoc = 0, loc = 0;
- while ((loc = inString.indexOf(substring, oldLoc))!= -1) {
+ while ((loc = inString.indexOf(substring, oldLoc)) != -1) {
result.append(inString.substring(oldLoc, loc));
oldLoc = loc + substring.length();
}
@@ -148,20 +151,23 @@
}
/**
- * Replaces with a new string, all occurrences of a string from
- * another string.
- *
- * @param inString the string to replace substrings in.
- * @param subString the substring to replace.
- * @param replaceString the replacement substring
+ * Replaces with a new string, all occurrences of a string from another
+ * string.
+ *
+ * @param inString
+ * the string to replace substrings in.
+ * @param subString
+ * the substring to replace.
+ * @param replaceString
+ * the replacement substring
* @return the input string with occurrences of substring replaced.
*/
public static String replaceSubstring(String inString, String subString,
- String replaceString) {
+ String replaceString) {
StringBuffer result = new StringBuffer();
int oldLoc = 0, loc = 0;
- while ((loc = inString.indexOf(subString, oldLoc))!= -1) {
+ while ((loc = inString.indexOf(subString, oldLoc)) != -1) {
result.append(inString.substring(oldLoc, loc));
result.append(replaceString);
oldLoc = loc + subString.length();
@@ -170,110 +176,116 @@
return result.toString();
}
-
/**
- * Pads a string to a specified length, inserting spaces on the left
- * as required. If the string is too long, characters are removed (from
- * the right).
- *
- * @param inString the input string
- * @param length the desired length of the output string
+ * Pads a string to a specified length, inserting spaces on the left as
+ * required. If the string is too long, characters are removed (from the
+ * right).
+ *
+ * @param inString
+ * the input string
+ * @param length
+ * the desired length of the output string
* @return the output string
*/
public static String padLeft(String inString, int length) {
return fixStringLength(inString, length, false);
}
-
+
/**
- * Pads a string to a specified length, inserting spaces on the right
- * as required. If the string is too long, characters are removed (from
- * the right).
- *
- * @param inString the input string
- * @param length the desired length of the output string
+ * Pads a string to a specified length, inserting spaces on the right as
+ * required. If the string is too long, characters are removed (from the
+ * right).
+ *
+ * @param inString
+ * the input string
+ * @param length
+ * the desired length of the output string
* @return the output string
*/
public static String padRight(String inString, int length) {
return fixStringLength(inString, length, true);
}
-
+
/**
- * Pads a string to a specified length, inserting spaces as
- * required. If the string is too long, characters are removed (from
- * the right).
- *
- * @param inString the input string
- * @param length the desired length of the output string
- * @param right true if inserted spaces should be added to the right
+ * Pads a string to a specified length, inserting spaces as required. If the
+ * string is too long, characters are removed (from the right).
+ *
+ * @param inString
+ * the input string
+ * @param length
+ * the desired length of the output string
+ * @param right
+ * true if inserted spaces should be added to the right
* @return the output string
*/
- private static /*@pure@*/ String fixStringLength(String inString, int length,
- boolean right) {
+ private static /* @pure@ */ String fixStringLength(String inString, int length,
+ boolean right) {
if (inString.length() < length) {
while (inString.length() < length) {
- inString = (right ? inString.concat(" ") : " ".concat(inString));
+ inString = (right ? inString.concat(" ") : " ".concat(inString));
}
} else if (inString.length() > length) {
inString = inString.substring(0, length);
}
return inString;
}
-
+
/**
* Rounds a double and converts it into String.
- *
- * @param value the double value
- * @param afterDecimalPoint the (maximum) number of digits permitted
- * after the decimal point
+ *
+ * @param value
+ * the double value
+ * @param afterDecimalPoint
+ * the (maximum) number of digits permitted after the decimal point
* @return the double as a formatted string
*/
- public static /*@pure@*/ String doubleToString(double value, int afterDecimalPoint) {
-
+ public static /* @pure@ */ String doubleToString(double value, int afterDecimalPoint) {
+
StringBuffer stringBuffer;
double temp;
int dotPosition;
long precisionValue;
-
+
temp = value * Math.pow(10.0, afterDecimalPoint);
if (Math.abs(temp) < Long.MAX_VALUE) {
- precisionValue = (temp > 0) ? (long)(temp + 0.5)
- : -(long)(Math.abs(temp) + 0.5);
+ precisionValue = (temp > 0) ? (long) (temp + 0.5)
+ : -(long) (Math.abs(temp) + 0.5);
if (precisionValue == 0) {
- stringBuffer = new StringBuffer(String.valueOf(0));
+ stringBuffer = new StringBuffer(String.valueOf(0));
} else {
- stringBuffer = new StringBuffer(String.valueOf(precisionValue));
+ stringBuffer = new StringBuffer(String.valueOf(precisionValue));
}
if (afterDecimalPoint == 0) {
- return stringBuffer.toString();
+ return stringBuffer.toString();
}
dotPosition = stringBuffer.length() - afterDecimalPoint;
while (((precisionValue < 0) && (dotPosition < 1)) ||
- (dotPosition < 0)) {
- if (precisionValue < 0) {
- stringBuffer.insert(1, '0');
- } else {
- stringBuffer.insert(0, '0');
- }
- dotPosition++;
+ (dotPosition < 0)) {
+ if (precisionValue < 0) {
+ stringBuffer.insert(1, '0');
+ } else {
+ stringBuffer.insert(0, '0');
+ }
+ dotPosition++;
}
stringBuffer.insert(dotPosition, '.');
if ((precisionValue < 0) && (stringBuffer.charAt(1) == '.')) {
- stringBuffer.insert(1, '0');
+ stringBuffer.insert(1, '0');
} else if (stringBuffer.charAt(0) == '.') {
- stringBuffer.insert(0, '0');
+ stringBuffer.insert(0, '0');
}
int currentPos = stringBuffer.length() - 1;
while ((currentPos > dotPosition) &&
- (stringBuffer.charAt(currentPos) == '0')) {
- stringBuffer.setCharAt(currentPos--, ' ');
+ (stringBuffer.charAt(currentPos) == '0')) {
+ stringBuffer.setCharAt(currentPos--, ' ');
}
if (stringBuffer.charAt(currentPos) == '.') {
- stringBuffer.setCharAt(currentPos, ' ');
+ stringBuffer.setCharAt(currentPos, ' ');
}
-
+
return stringBuffer.toString().trim();
}
return new String("" + value);
@@ -282,20 +294,23 @@
/**
* Rounds a double and converts it into a formatted decimal-justified String.
* Trailing 0's are replaced with spaces.
- *
- * @param value the double value
- * @param width the width of the string
- * @param afterDecimalPoint the number of digits after the decimal point
+ *
+ * @param value
+ * the double value
+ * @param width
+ * the width of the string
+ * @param afterDecimalPoint
+ * the number of digits after the decimal point
* @return the double as a formatted string
*/
- public static /*@pure@*/ String doubleToString(double value, int width,
- int afterDecimalPoint) {
-
+ public static /* @pure@ */ String doubleToString(double value, int width,
+ int afterDecimalPoint) {
+
String tempString = doubleToString(value, afterDecimalPoint);
char[] result;
int dotPosition;
- if ((afterDecimalPoint >= width)
+ if ((afterDecimalPoint >= width)
|| (tempString.indexOf('E') != -1)) { // Protects sci notation
return tempString;
}
@@ -310,14 +325,13 @@
// Get position of decimal point and insert decimal point
dotPosition = tempString.indexOf('.');
if (dotPosition == -1) {
- dotPosition = tempString.length();
+ dotPosition = tempString.length();
} else {
- result[width - afterDecimalPoint - 1] = '.';
+ result[width - afterDecimalPoint - 1] = '.';
}
} else {
dotPosition = tempString.length();
}
-
int offset = width - afterDecimalPoint - dotPosition;
if (afterDecimalPoint > 0) {
@@ -345,23 +359,26 @@
/**
* Returns the basic class of an array class (handles multi-dimensional
* arrays).
- * @param c the array to inspect
- * @return the class of the innermost elements
+ *
+ * @param c
+ * the array to inspect
+ * @return the class of the innermost elements
*/
public static Class getArrayClass(Class c) {
- if (c.getComponentType().isArray())
- return getArrayClass(c.getComponentType());
- else
- return c.getComponentType();
+ if (c.getComponentType().isArray())
+ return getArrayClass(c.getComponentType());
+ else
+ return c.getComponentType();
}
/**
- * Returns the dimensions of the given array. Even though the
- * parameter is of type "Object" one can hand over primitve arrays, e.g.
- * int[3] or double[2][4].
- *
- * @param array the array to determine the dimensions for
- * @return the dimensions of the array
+ * Returns the dimensions of the given array. Even though the parameter is of
+ * type "Object" one can hand over primitve arrays, e.g. int[3] or
+ * double[2][4].
+ *
+ * @param array
+ * the array to determine the dimensions for
+ * @return the dimensions of the array
*/
public static int getArrayDimensions(Class array) {
if (array.getComponentType().isArray())
@@ -371,12 +388,13 @@
}
/**
- * Returns the dimensions of the given array. Even though the
- * parameter is of type "Object" one can hand over primitve arrays, e.g.
- * int[3] or double[2][4].
- *
- * @param array the array to determine the dimensions for
- * @return the dimensions of the array
+ * Returns the dimensions of the given array. Even though the parameter is of
+ * type "Object" one can hand over primitve arrays, e.g. int[3] or
+ * double[2][4].
+ *
+ * @param array
+ * the array to determine the dimensions for
+ * @return the dimensions of the array
*/
public static int getArrayDimensions(Object array) {
return getArrayDimensions(array.getClass());
@@ -387,17 +405,18 @@
* parameter is of type "Object" one can hand over primitve arrays, e.g.
* int[3] or double[2][4].
*
- * @param array the array to return in a string representation
- * @return the array as string
+ * @param array
+ * the array to return in a string representation
+ * @return the array as string
*/
public static String arrayToString(Object array) {
- String result;
- int dimensions;
- int i;
+ String result;
+ int dimensions;
+ int i;
- result = "";
+ result = "";
dimensions = getArrayDimensions(array);
-
+
if (dimensions == 0) {
result = "null";
}
@@ -418,30 +437,34 @@
result += "[" + arrayToString(Array.get(array, i)) + "]";
}
}
-
+
return result;
}
/**
* Tests if a is equal to b.
- *
- * @param a a double
- * @param b a double
+ *
+ * @param a
+ * a double
+ * @param b
+ * a double
*/
- public static /*@pure@*/ boolean eq(double a, double b){
-
- return (a - b < SMALL) && (b - a < SMALL);
+ public static /* @pure@ */ boolean eq(double a, double b) {
+
+ return (a - b < SMALL) && (b - a < SMALL);
}
/**
* Checks if the given array contains any non-empty options.
- *
- * @param options an array of strings
- * @exception Exception if there are any non-empty options
+ *
+ * @param options
+ * an array of strings
+ * @exception Exception
+ * if there are any non-empty options
*/
- public static void checkForRemainingOptions(String[] options)
- throws Exception {
-
+ public static void checkForRemainingOptions(String[] options)
+ throws Exception {
+
int illegalOptionsFound = 0;
StringBuffer text = new StringBuffer();
@@ -450,143 +473,158 @@
}
for (int i = 0; i < options.length; i++) {
if (options[i].length() > 0) {
- illegalOptionsFound++;
- text.append(options[i] + ' ');
+ illegalOptionsFound++;
+ text.append(options[i] + ' ');
}
}
if (illegalOptionsFound > 0) {
throw new Exception("Illegal options: " + text);
}
}
-
+
/**
- * Checks if the given array contains the flag "-Char". Stops
- * searching at the first marker "--". If the flag is found,
- * it is replaced with the empty string.
- *
- * @param flag the character indicating the flag.
- * @param options the array of strings containing all the options.
+ * Checks if the given array contains the flag "-Char". Stops searching at the
+ * first marker "--". If the flag is found, it is replaced with the empty
+ * string.
+ *
+ * @param flag
+ * the character indicating the flag.
+ * @param options
+ * the array of strings containing all the options.
* @return true if the flag was found
- * @exception Exception if an illegal option was found
+ * @exception Exception
+ * if an illegal option was found
*/
- public static boolean getFlag(char flag, String[] options)
- throws Exception {
-
+ public static boolean getFlag(char flag, String[] options)
+ throws Exception {
+
return getFlag("" + flag, options);
}
-
+
/**
- * Checks if the given array contains the flag "-String". Stops
- * searching at the first marker "--". If the flag is found,
- * it is replaced with the empty string.
- *
- * @param flag the String indicating the flag.
- * @param options the array of strings containing all the options.
+ * Checks if the given array contains the flag "-String". Stops searching at
+ * the first marker "--". If the flag is found, it is replaced with the empty
+ * string.
+ *
+ * @param flag
+ * the String indicating the flag.
+ * @param options
+ * the array of strings containing all the options.
* @return true if the flag was found
- * @exception Exception if an illegal option was found
+ * @exception Exception
+ * if an illegal option was found
*/
- public static boolean getFlag(String flag, String[] options)
- throws Exception {
-
+ public static boolean getFlag(String flag, String[] options)
+ throws Exception {
+
int pos = getOptionPos(flag, options);
if (pos > -1)
options[pos] = "";
-
+
return (pos > -1);
}
/**
- * Gets an option indicated by a flag "-Char" from the given array
- * of strings. Stops searching at the first marker "--". Replaces
- * flag and option with empty strings.
- *
- * @param flag the character indicating the option.
- * @param options the array of strings containing all the options.
+ * Gets an option indicated by a flag "-Char" from the given array of strings.
+ * Stops searching at the first marker "--". Replaces flag and option with
+ * empty strings.
+ *
+ * @param flag
+ * the character indicating the option.
+ * @param options
+ * the array of strings containing all the options.
* @return the indicated option or an empty string
- * @exception Exception if the option indicated by the flag can't be found
+ * @exception Exception
+ * if the option indicated by the flag can't be found
*/
- public static /*@non_null@*/ String getOption(char flag, String[] options)
- throws Exception {
-
+ public static /* @non_null@ */ String getOption(char flag, String[] options)
+ throws Exception {
+
return getOption("" + flag, options);
}
/**
- * Gets an option indicated by a flag "-String" from the given array
- * of strings. Stops searching at the first marker "--". Replaces
- * flag and option with empty strings.
- *
- * @param flag the String indicating the option.
- * @param options the array of strings containing all the options.
+ * Gets an option indicated by a flag "-String" from the given array of
+ * strings. Stops searching at the first marker "--". Replaces flag and option
+ * with empty strings.
+ *
+ * @param flag
+ * the String indicating the option.
+ * @param options
+ * the array of strings containing all the options.
* @return the indicated option or an empty string
- * @exception Exception if the option indicated by the flag can't be found
+ * @exception Exception
+ * if the option indicated by the flag can't be found
*/
- public static /*@non_null@*/ String getOption(String flag, String[] options)
- throws Exception {
+ public static /* @non_null@ */ String getOption(String flag, String[] options)
+ throws Exception {
String newString;
int i = getOptionPos(flag, options);
if (i > -1) {
if (options[i].equals("-" + flag)) {
- if (i + 1 == options.length) {
- throw new Exception("No value given for -" + flag + " option.");
- }
- options[i] = "";
- newString = new String(options[i + 1]);
- options[i + 1] = "";
- return newString;
+ if (i + 1 == options.length) {
+ throw new Exception("No value given for -" + flag + " option.");
+ }
+ options[i] = "";
+ newString = new String(options[i + 1]);
+ options[i + 1] = "";
+ return newString;
}
if (options[i].charAt(1) == '-') {
- return "";
+ return "";
}
}
-
+
return "";
}
/**
- * Gets the index of an option or flag indicated by a flag "-Char" from
- * the given array of strings. Stops searching at the first marker "--".
- *
- * @param flag the character indicating the option.
- * @param options the array of strings containing all the options.
- * @return the position if found, or -1 otherwise
+ * Gets the index of an option or flag indicated by a flag "-Char" from the
+ * given array of strings. Stops searching at the first marker "--".
+ *
+ * @param flag
+ * the character indicating the option.
+ * @param options
+ * the array of strings containing all the options.
+ * @return the position if found, or -1 otherwise
*/
public static int getOptionPos(char flag, String[] options) {
- return getOptionPos("" + flag, options);
+ return getOptionPos("" + flag, options);
}
/**
- * Gets the index of an option or flag indicated by a flag "-String" from
- * the given array of strings. Stops searching at the first marker "--".
- *
- * @param flag the String indicating the option.
- * @param options the array of strings containing all the options.
- * @return the position if found, or -1 otherwise
+ * Gets the index of an option or flag indicated by a flag "-String" from the
+ * given array of strings. Stops searching at the first marker "--".
+ *
+ * @param flag
+ * the String indicating the option.
+ * @param options
+ * the array of strings containing all the options.
+ * @return the position if found, or -1 otherwise
*/
public static int getOptionPos(String flag, String[] options) {
if (options == null)
return -1;
-
+
for (int i = 0; i < options.length; i++) {
if ((options[i].length() > 0) && (options[i].charAt(0) == '-')) {
- // Check if it is a negative number
- try {
- Double.valueOf(options[i]);
- }
- catch (NumberFormatException e) {
- // found?
- if (options[i].equals("-" + flag))
- return i;
- // did we reach "--"?
- if (options[i].charAt(1) == '-')
- return -1;
- }
+ // Check if it is a negative number
+ try {
+ Double.valueOf(options[i]);
+ } catch (NumberFormatException e) {
+ // found?
+ if (options[i].equals("-" + flag))
+ return i;
+ // did we reach "--"?
+ if (options[i].charAt(1) == '-')
+ return -1;
+ }
}
}
-
+
return -1;
}
@@ -594,65 +632,66 @@
* Quotes a string if it contains special characters.
*
* The following rules are applied:
- *
- * A character is backquoted version of it is one
- * of <tt>" ' % \ \n \r \t</tt>.
- *
+ *
+ * A character is backquoted if it is one of <tt>" ' % \ \n \r \t</tt>
+ * .
+ *
* A string is enclosed within single quotes if a character has been
- * backquoted using the previous rule above or contains
- * <tt>{ }</tt> or is exactly equal to the strings
- * <tt>, ? space or ""</tt> (empty string).
- *
- * A quoted question mark distinguishes it from the missing value which
- * is represented as an unquoted question mark in arff files.
- *
- * @param string the string to be quoted
- * @return the string (possibly quoted)
- * @see #unquote(String)
+ * backquoted using the previous rule above or contains <tt>{ }</tt> or is
+ * exactly equal to the strings <tt>, ? space or ""</tt> (empty string).
+ *
+ * A quoted question mark distinguishes it from the missing value which is
+ * represented as an unquoted question mark in arff files.
+ *
+ * @param string
+ * the string to be quoted
+ * @return the string (possibly quoted)
+ * @see #unquote(String)
*/
- public static /*@pure@*/ String quote(String string) {
- boolean quote = false;
+ public static/* @pure@ */String quote(String string) {
+ boolean quote = false;
- // backquote the following characters
- if ((string.indexOf('\n') != -1) || (string.indexOf('\r') != -1) ||
- (string.indexOf('\'') != -1) || (string.indexOf('"') != -1) ||
- (string.indexOf('\\') != -1) ||
- (string.indexOf('\t') != -1) || (string.indexOf('%') != -1) ||
- (string.indexOf('\u001E') != -1)) {
- string = backQuoteChars(string);
- quote = true;
- }
+ // backquote the following characters
+ if ((string.indexOf('\n') != -1) || (string.indexOf('\r') != -1) ||
+ (string.indexOf('\'') != -1) || (string.indexOf('"') != -1) ||
+ (string.indexOf('\\') != -1) ||
+ (string.indexOf('\t') != -1) || (string.indexOf('%') != -1) ||
+ (string.indexOf('\u001E') != -1)) {
+ string = backQuoteChars(string);
+ quote = true;
+ }
- // Enclose the string in 's if the string contains a recently added
- // backquote or contains one of the following characters.
- if((quote == true) ||
- (string.indexOf('{') != -1) || (string.indexOf('}') != -1) ||
- (string.indexOf(',') != -1) || (string.equals("?")) ||
- (string.indexOf(' ') != -1) || (string.equals(""))) {
- string = ("'".concat(string)).concat("'");
- }
+ // Enclose the string in 's if the string contains a recently added
+ // backquote or contains one of the following characters.
+ if ((quote == true) ||
+ (string.indexOf('{') != -1) || (string.indexOf('}') != -1) ||
+ (string.indexOf(',') != -1) || (string.equals("?")) ||
+ (string.indexOf(' ') != -1) || (string.equals(""))) {
+ string = ("'".concat(string)).concat("'");
+ }
- return string;
+ return string;
}
/**
* unquotes are previously quoted string (but only if necessary), i.e., it
* removes the single quotes around it. Inverse to quote(String).
*
- * @param string the string to process
- * @return the unquoted string
- * @see #quote(String)
+ * @param string
+ * the string to process
+ * @return the unquoted string
+ * @see #quote(String)
*/
public static String unquote(String string) {
if (string.startsWith("'") && string.endsWith("'")) {
string = string.substring(1, string.length() - 1);
-
- if ((string.indexOf("\\n") != -1) || (string.indexOf("\\r") != -1) ||
- (string.indexOf("\\'") != -1) || (string.indexOf("\\\"") != -1) ||
- (string.indexOf("\\\\") != -1) ||
- (string.indexOf("\\t") != -1) || (string.indexOf("\\%") != -1) ||
- (string.indexOf("\\u001E") != -1)) {
- string = unbackQuoteChars(string);
+
+ if ((string.indexOf("\\n") != -1) || (string.indexOf("\\r") != -1) ||
+ (string.indexOf("\\'") != -1) || (string.indexOf("\\\"") != -1) ||
+ (string.indexOf("\\\\") != -1) ||
+ (string.indexOf("\\t") != -1) || (string.indexOf("\\%") != -1) ||
+ (string.indexOf("\\u001E") != -1)) {
+ string = unbackQuoteChars(string);
}
}
@@ -663,36 +702,37 @@
* Converts carriage returns and new lines in a string into \r and \n.
* Backquotes the following characters: ` " \ \t and %
*
- * @param string the string
- * @return the converted string
- * @see #unbackQuoteChars(String)
+ * @param string
+ * the string
+ * @return the converted string
+ * @see #unbackQuoteChars(String)
*/
- public static /*@pure@*/ String backQuoteChars(String string) {
+ public static/* @pure@ */String backQuoteChars(String string) {
int index;
StringBuffer newStringBuffer;
// replace each of the following characters with the backquoted version
- char charsFind[] = {'\\', '\'', '\t', '\n', '\r', '"', '%',
- '\u001E'};
- String charsReplace[] = {"\\\\", "\\'", "\\t", "\\n", "\\r", "\\\"", "\\%",
- "\\u001E"};
+ char charsFind[] = { '\\', '\'', '\t', '\n', '\r', '"', '%',
+ '\u001E' };
+ String charsReplace[] = { "\\\\", "\\'", "\\t", "\\n", "\\r", "\\\"", "\\%",
+ "\\u001E" };
for (int i = 0; i < charsFind.length; i++) {
- if (string.indexOf(charsFind[i]) != -1 ) {
- newStringBuffer = new StringBuffer();
- while ((index = string.indexOf(charsFind[i])) != -1) {
- if (index > 0) {
- newStringBuffer.append(string.substring(0, index));
- }
- newStringBuffer.append(charsReplace[i]);
- if ((index + 1) < string.length()) {
- string = string.substring(index + 1);
- } else {
- string = "";
- }
- }
- newStringBuffer.append(string);
- string = newStringBuffer.toString();
+ if (string.indexOf(charsFind[i]) != -1) {
+ newStringBuffer = new StringBuffer();
+ while ((index = string.indexOf(charsFind[i])) != -1) {
+ if (index > 0) {
+ newStringBuffer.append(string.substring(0, index));
+ }
+ newStringBuffer.append(charsReplace[i]);
+ if ((index + 1) < string.length()) {
+ string = string.substring(index + 1);
+ } else {
+ string = "";
+ }
+ }
+ newStringBuffer.append(string);
+ string = newStringBuffer.toString();
}
}
@@ -701,8 +741,9 @@
/**
* Converts carriage returns and new lines in a string into \r and \n.
- *
- * @param string the string
+ *
+ * @param string
+ * the string
* @return the converted string
*/
public static String convertNewLines(String string) {
@@ -712,14 +753,14 @@
StringBuffer newStringBuffer = new StringBuffer();
while ((index = string.indexOf('\n')) != -1) {
if (index > 0) {
- newStringBuffer.append(string.substring(0, index));
+ newStringBuffer.append(string.substring(0, index));
}
newStringBuffer.append('\\');
newStringBuffer.append('n');
if ((index + 1) < string.length()) {
- string = string.substring(index + 1);
+ string = string.substring(index + 1);
} else {
- string = "";
+ string = "";
}
}
newStringBuffer.append(string);
@@ -729,14 +770,14 @@
newStringBuffer = new StringBuffer();
while ((index = string.indexOf('\r')) != -1) {
if (index > 0) {
- newStringBuffer.append(string.substring(0, index));
+ newStringBuffer.append(string.substring(0, index));
}
newStringBuffer.append('\\');
newStringBuffer.append('r');
- if ((index + 1) < string.length()){
- string = string.substring(index + 1);
+ if ((index + 1) < string.length()) {
+ string = string.substring(index + 1);
} else {
- string = "";
+ string = "";
}
}
newStringBuffer.append(string);
@@ -746,7 +787,8 @@
/**
* Reverts \r and \n in a string into carriage returns and new lines.
*
- * @param string the string
+ * @param string
+ * the string
* @return the converted string
*/
public static String revertNewLines(String string) {
@@ -756,13 +798,13 @@
StringBuffer newStringBuffer = new StringBuffer();
while ((index = string.indexOf("\\n")) != -1) {
if (index > 0) {
- newStringBuffer.append(string.substring(0, index));
+ newStringBuffer.append(string.substring(0, index));
}
newStringBuffer.append('\n');
if ((index + 2) < string.length()) {
- string = string.substring(index + 2);
+ string = string.substring(index + 2);
} else {
- string = "";
+ string = "";
}
}
newStringBuffer.append(string);
@@ -772,168 +814,176 @@
newStringBuffer = new StringBuffer();
while ((index = string.indexOf("\\r")) != -1) {
if (index > 0) {
- newStringBuffer.append(string.substring(0, index));
+ newStringBuffer.append(string.substring(0, index));
}
newStringBuffer.append('\r');
- if ((index + 2) < string.length()){
- string = string.substring(index + 2);
+ if ((index + 2) < string.length()) {
+ string = string.substring(index + 2);
} else {
- string = "";
+ string = "";
}
}
newStringBuffer.append(string);
-
+
return newStringBuffer.toString();
}
/**
- * Returns the secondary set of options (if any) contained in
- * the supplied options array. The secondary set is defined to
- * be any options after the first "--". These options are removed from
- * the original options array.
- *
- * @param options the input array of options
+ * Returns the secondary set of options (if any) contained in the supplied
+ * options array. The secondary set is defined to be any options after the
+ * first "--". These options are removed from the original options array.
+ *
+ * @param options
+ * the input array of options
* @return the array of secondary options
*/
public static String[] partitionOptions(String[] options) {
for (int i = 0; i < options.length; i++) {
if (options[i].equals("--")) {
- options[i++] = "";
- String[] result = new String [options.length - i];
- for (int j = i; j < options.length; j++) {
- result[j - i] = options[j];
- options[j] = "";
- }
- return result;
+ options[i++] = "";
+ String[] result = new String[options.length - i];
+ for (int j = i; j < options.length; j++) {
+ result[j - i] = options[j];
+ options[j] = "";
+ }
+ return result;
}
}
- return new String [0];
+ return new String[0];
}
-
+
/**
- * The inverse operation of backQuoteChars().
- * Converts back-quoted carriage returns and new lines in a string
- * to the corresponding character ('\r' and '\n').
- * Also "un"-back-quotes the following characters: ` " \ \t and %
- *
- * @param string the string
- * @return the converted string
- * @see #backQuoteChars(String)
+ * The inverse operation of backQuoteChars(). Converts back-quoted carriage
+ * returns and new lines in a string to the corresponding character ('\r' and
+ * '\n'). Also "un"-back-quotes the following characters: ` " \ \t and %
+ *
+ * @param string
+ * the string
+ * @return the converted string
+ * @see #backQuoteChars(String)
*/
public static String unbackQuoteChars(String string) {
int index;
StringBuffer newStringBuffer;
-
+
// replace each of the following characters with the backquoted version
- String charsFind[] = {"\\\\", "\\'", "\\t", "\\n", "\\r", "\\\"", "\\%",
- "\\u001E"};
- char charsReplace[] = {'\\', '\'', '\t', '\n', '\r', '"', '%',
- '\u001E'};
+ String charsFind[] = { "\\\\", "\\'", "\\t", "\\n", "\\r", "\\\"", "\\%",
+ "\\u001E" };
+ char charsReplace[] = { '\\', '\'', '\t', '\n', '\r', '"', '%',
+ '\u001E' };
int pos[] = new int[charsFind.length];
- int curPos;
-
+ int curPos;
+
String str = new String(string);
newStringBuffer = new StringBuffer();
while (str.length() > 0) {
// get positions and closest character to replace
curPos = str.length();
- index = -1;
+ index = -1;
for (int i = 0; i < pos.length; i++) {
- pos[i] = str.indexOf(charsFind[i]);
- if ( (pos[i] > -1) && (pos[i] < curPos) ) {
- index = i;
- curPos = pos[i];
- }
+ pos[i] = str.indexOf(charsFind[i]);
+ if ((pos[i] > -1) && (pos[i] < curPos)) {
+ index = i;
+ curPos = pos[i];
+ }
}
-
+
// replace character if found, otherwise finished
if (index == -1) {
- newStringBuffer.append(str);
- str = "";
+ newStringBuffer.append(str);
+ str = "";
}
else {
- newStringBuffer.append(str.substring(0, pos[index]));
- newStringBuffer.append(charsReplace[index]);
- str = str.substring(pos[index] + charsFind[index].length());
+ newStringBuffer.append(str.substring(0, pos[index]));
+ newStringBuffer.append(charsReplace[index]);
+ str = str.substring(pos[index] + charsFind[index].length());
}
}
return newStringBuffer.toString();
- }
-
+ }
+
/**
- * Split up a string containing options into an array of strings,
- * one for each option.
- *
- * @param quotedOptionString the string containing the options
- * @return the array of options
- * @throws Exception in case of an unterminated string, unknown character or
- * a parse error
+ * Split up a string containing options into an array of strings, one for each
+ * option.
+ *
+ * @param quotedOptionString
+ * the string containing the options
+ * @return the array of options
+ * @throws Exception
+ * in case of an unterminated string, unknown character or a parse
+ * error
*/
- public static String[] splitOptions(String quotedOptionString) throws Exception{
+ public static String[] splitOptions(String quotedOptionString) throws Exception {
Vector<String> optionsVec = new Vector<String>();
String str = new String(quotedOptionString);
int i;
-
- while (true){
- //trimLeft
+ while (true) {
+
+ // trimLeft
i = 0;
- while ((i < str.length()) && (Character.isWhitespace(str.charAt(i)))) i++;
+ while ((i < str.length()) && (Character.isWhitespace(str.charAt(i))))
+ i++;
str = str.substring(i);
-
- //stop when str is empty
- if (str.length() == 0) break;
-
- //if str start with a double quote
- if (str.charAt(0) == '"'){
-
- //find the first not anti-slached double quote
- i = 1;
- while(i < str.length()){
- if (str.charAt(i) == str.charAt(0)) break;
- if (str.charAt(i) == '\\'){
- i += 1;
- if (i >= str.length())
- throw new Exception("String should not finish with \\");
- }
- i += 1;
- }
- if (i >= str.length()) throw new Exception("Quote parse error.");
-
- //add the founded string to the option vector (without quotes)
- String optStr = str.substring(1,i);
- optStr = unbackQuoteChars(optStr);
- optionsVec.addElement(optStr);
- str = str.substring(i+1);
+
+ // stop when str is empty
+ if (str.length() == 0)
+ break;
+
+ // if str starts with a double quote
+ if (str.charAt(0) == '"') {
+
+ // find the first non-backslashed double quote
+ i = 1;
+ while (i < str.length()) {
+ if (str.charAt(i) == str.charAt(0))
+ break;
+ if (str.charAt(i) == '\\') {
+ i += 1;
+ if (i >= str.length())
+ throw new Exception("String should not finish with \\");
+ }
+ i += 1;
+ }
+ if (i >= str.length())
+ throw new Exception("Quote parse error.");
+
+ // add the found string to the option vector (without quotes)
+ String optStr = str.substring(1, i);
+ optStr = unbackQuoteChars(optStr);
+ optionsVec.addElement(optStr);
+ str = str.substring(i + 1);
} else {
- //find first whiteSpace
- i=0;
- while((i < str.length()) && (!Character.isWhitespace(str.charAt(i)))) i++;
-
- //add the founded string to the option vector
- String optStr = str.substring(0,i);
- optionsVec.addElement(optStr);
- str = str.substring(i);
+ // find first whiteSpace
+ i = 0;
+ while ((i < str.length()) && (!Character.isWhitespace(str.charAt(i))))
+ i++;
+
+ // add the found string to the option vector
+ String optStr = str.substring(0, i);
+ optionsVec.addElement(optStr);
+ str = str.substring(i);
}
}
-
- //convert optionsVec to an array of String
+
+ // convert optionsVec to an array of String
String[] options = new String[optionsVec.size()];
for (i = 0; i < optionsVec.size(); i++) {
- options[i] = (String)optionsVec.elementAt(i);
+ options[i] = (String) optionsVec.elementAt(i);
}
return options;
- }
+ }
/**
- * Joins all the options in an option array into a single string,
- * as might be used on the command line.
- *
- * @param optionArray the array of options
+ * Joins all the options in an option array into a single string, as might be
+ * used on the command line.
+ *
+ * @param optionArray
+ * the array of options
* @return the string containing all options.
*/
public static String joinOptions(String[] optionArray) {
@@ -941,35 +991,35 @@
String optionString = "";
for (int i = 0; i < optionArray.length; i++) {
if (optionArray[i].equals("")) {
- continue;
+ continue;
}
boolean escape = false;
for (int n = 0; n < optionArray[i].length(); n++) {
- if (Character.isWhitespace(optionArray[i].charAt(n))) {
- escape = true;
- break;
- }
+ if (Character.isWhitespace(optionArray[i].charAt(n))) {
+ escape = true;
+ break;
+ }
}
if (escape) {
- optionString += '"' + backQuoteChars(optionArray[i]) + '"';
+ optionString += '"' + backQuoteChars(optionArray[i]) + '"';
} else {
- optionString += optionArray[i];
+ optionString += optionArray[i];
}
optionString += " ";
}
return optionString.trim();
}
-
-
+
/**
* Computes entropy for an array of integers.
- *
- * @param counts array of counts
- * @return - a log2 a - b log2 b - c log2 c + (a+b+c) log2 (a+b+c)
- * when given array [a b c]
+ *
+ * @param counts
+ * array of counts
+ * @return - a log2 a - b log2 b - c log2 c + (a+b+c) log2 (a+b+c) when given
+ * array [a b c]
*/
- public static /*@pure@*/ double info(int counts[]) {
-
+ public static/* @pure@ */double info(int counts[]) {
+
int total = 0;
double x = 0;
for (int j = 0; j < counts.length; j++) {
@@ -981,59 +1031,69 @@
/**
* Tests if a is smaller or equal to b.
- *
- * @param a a double
- * @param b a double
+ *
+ * @param a
+ * a double
+ * @param b
+ * a double
*/
- public static /*@pure@*/ boolean smOrEq(double a,double b) {
-
- return (a-b < SMALL);
+ public static/* @pure@ */boolean smOrEq(double a, double b) {
+
+ return (a - b < SMALL);
}
/**
* Tests if a is greater or equal to b.
- *
- * @param a a double
- * @param b a double
+ *
+ * @param a
+ * a double
+ * @param b
+ * a double
*/
- public static /*@pure@*/ boolean grOrEq(double a,double b) {
-
- return (b-a < SMALL);
+ public static/* @pure@ */boolean grOrEq(double a, double b) {
+
+ return (b - a < SMALL);
}
-
+
/**
* Tests if a is smaller than b.
- *
- * @param a a double
- * @param b a double
+ *
+ * @param a
+ * a double
+ * @param b
+ * a double
*/
- public static /*@pure@*/ boolean sm(double a,double b) {
-
- return (b-a > SMALL);
+ public static/* @pure@ */boolean sm(double a, double b) {
+
+ return (b - a > SMALL);
}
/**
* Tests if a is greater than b.
- *
- * @param a a double
- * @param b a double
+ *
+ * @param a
+ * a double
+ * @param b
+ * a double
*/
- public static /*@pure@*/ boolean gr(double a,double b) {
-
- return (a-b > SMALL);
+ public static/* @pure@ */boolean gr(double a, double b) {
+
+ return (a - b > SMALL);
}
/**
* Returns the kth-smallest value in the array.
- *
- * @param array the array of integers
- * @param k the value of k
+ *
+ * @param array
+ * the array of integers
+ * @param k
+ * the value of k
* @return the kth-smallest value
*/
public static double kthSmallestValue(int[] array, int k) {
int[] index = new int[array.length];
-
+
for (int i = 0; i < index.length; i++) {
index[i] = i;
}
@@ -1043,15 +1103,17 @@
/**
* Returns the kth-smallest value in the array
- *
- * @param array the array of double
- * @param k the value of k
+ *
+ * @param array
+ * the array of double
+ * @param k
+ * the value of k
* @return the kth-smallest value
*/
public static double kthSmallestValue(double[] array, int k) {
int[] index = new int[array.length];
-
+
for (int i = 0; i < index.length; i++) {
index[i] = i;
}
@@ -1061,31 +1123,33 @@
/**
* Returns the logarithm of a for base 2.
- *
- * @param a a double
- * @return the logarithm for base 2
+ *
+ * @param a
+ * a double
+ * @return the logarithm for base 2
*/
- public static /*@pure@*/ double log2(double a) {
-
+ public static/* @pure@ */double log2(double a) {
+
return Math.log(a) / log2;
}
/**
- * Returns index of maximum element in a given
- * array of doubles. First maximum is returned.
- *
- * @param doubles the array of doubles
+ * Returns index of maximum element in a given array of doubles. First maximum
+ * is returned.
+ *
+ * @param doubles
+ * the array of doubles
* @return the index of the maximum element
*/
- public static /*@pure@*/ int maxIndex(double[] doubles) {
+ public static/* @pure@ */int maxIndex(double[] doubles) {
double maximum = 0;
int maxIndex = 0;
for (int i = 0; i < doubles.length; i++) {
if ((i == 0) || (doubles[i] > maximum)) {
- maxIndex = i;
- maximum = doubles[i];
+ maxIndex = i;
+ maximum = doubles[i];
}
}
@@ -1093,21 +1157,22 @@
}
/**
- * Returns index of maximum element in a given
- * array of integers. First maximum is returned.
- *
- * @param ints the array of integers
+ * Returns index of maximum element in a given array of integers. First
+ * maximum is returned.
+ *
+ * @param ints
+ * the array of integers
* @return the index of the maximum element
*/
- public static /*@pure@*/ int maxIndex(int[] ints) {
+ public static/* @pure@ */int maxIndex(int[] ints) {
int maximum = 0;
int maxIndex = 0;
for (int i = 0; i < ints.length; i++) {
if ((i == 0) || (ints[i] > maximum)) {
- maxIndex = i;
- maximum = ints[i];
+ maxIndex = i;
+ maximum = ints[i];
}
}
@@ -1116,12 +1181,13 @@
/**
* Computes the mean for an array of doubles.
- *
- * @param vector the array
+ *
+ * @param vector
+ * the array
* @return the mean
*/
- public static /*@pure@*/ double mean(double[] vector) {
-
+ public static/* @pure@ */double mean(double[] vector) {
+
double sum = 0;
if (vector.length == 0) {
@@ -1134,21 +1200,22 @@
}
/**
- * Returns index of minimum element in a given
- * array of integers. First minimum is returned.
- *
- * @param ints the array of integers
+ * Returns index of minimum element in a given array of integers. First
+ * minimum is returned.
+ *
+ * @param ints
+ * the array of integers
* @return the index of the minimum element
*/
- public static /*@pure@*/ int minIndex(int[] ints) {
+ public static/* @pure@ */int minIndex(int[] ints) {
int minimum = 0;
int minIndex = 0;
for (int i = 0; i < ints.length; i++) {
if ((i == 0) || (ints[i] < minimum)) {
- minIndex = i;
- minimum = ints[i];
+ minIndex = i;
+ minimum = ints[i];
}
}
@@ -1156,21 +1223,22 @@
}
/**
- * Returns index of minimum element in a given
- * array of doubles. First minimum is returned.
- *
- * @param doubles the array of doubles
+ * Returns index of minimum element in a given array of doubles. First minimum
+ * is returned.
+ *
+ * @param doubles
+ * the array of doubles
* @return the index of the minimum element
*/
- public static /*@pure@*/ int minIndex(double[] doubles) {
+ public static/* @pure@ */int minIndex(double[] doubles) {
double minimum = 0;
int minIndex = 0;
for (int i = 0; i < doubles.length; i++) {
if ((i == 0) || (doubles[i] < minimum)) {
- minIndex = i;
- minimum = doubles[i];
+ minIndex = i;
+ minimum = doubles[i];
}
}
@@ -1179,9 +1247,11 @@
/**
* Normalizes the doubles in the array by their sum.
- *
- * @param doubles the array of double
- * @exception IllegalArgumentException if sum is Zero or NaN
+ *
+ * @param doubles
+ * the array of double
+ * @exception IllegalArgumentException
+ * if sum is Zero or NaN
*/
public static void normalize(double[] doubles) {
@@ -1194,10 +1264,13 @@
/**
* Normalizes the doubles in the array using the given value.
- *
- * @param doubles the array of double
- * @param sum the value by which the doubles are to be normalized
- * @exception IllegalArgumentException if sum is zero or NaN
+ *
+ * @param doubles
+ * the array of double
+ * @param sum
+ * the value by which the doubles are to be normalized
+ * @exception IllegalArgumentException
+ * if sum is zero or NaN
*/
public static void normalize(double[] doubles, double sum) {
@@ -1214,12 +1287,13 @@
}
/**
- * Converts an array containing the natural logarithms of
- * probabilities stored in a vector back into probabilities.
- * The probabilities are assumed to sum to one.
- *
- * @param a an array holding the natural logarithms of the probabilities
- * @return the converted array
+ * Converts an array containing the natural logarithms of probabilities stored
+ * in a vector back into probabilities. The probabilities are assumed to sum
+ * to one.
+ *
+ * @param a
+ * an array holding the natural logarithms of the probabilities
+ * @return the converted array
*/
public static double[] logs2probs(double[] a) {
@@ -1227,7 +1301,7 @@
double sum = 0.0;
double[] result = new double[a.length];
- for(int i = 0; i < a.length; i++) {
+ for (int i = 0; i < a.length; i++) {
result[i] = Math.exp(a[i] - max);
sum += result[i];
}
@@ -1235,51 +1309,54 @@
normalize(result, sum);
return result;
- }
+ }
/**
* Returns the log-odds for a given probabilitiy.
- *
- * @param prob the probabilitiy
- *
- * @return the log-odds after the probability has been mapped to
- * [Utils.SMALL, 1-Utils.SMALL]
+ *
+ * @param prob
+ * the probability
+ *
+ * @return the log-odds after the probability has been mapped to [Utils.SMALL,
+ * 1-Utils.SMALL]
*/
- public static /*@pure@*/ double probToLogOdds(double prob) {
+ public static/* @pure@ */double probToLogOdds(double prob) {
if (gr(prob, 1) || (sm(prob, 0))) {
throw new IllegalArgumentException("probToLogOdds: probability must " +
- "be in [0,1] "+prob);
+ "be in [0,1] " + prob);
}
double p = SMALL + (1.0 - 2 * SMALL) * prob;
return Math.log(p / (1 - p));
}
/**
- * Rounds a double to the next nearest integer value. The JDK version
- * of it doesn't work properly.
- *
- * @param value the double value
+ * Rounds a double to the next nearest integer value. The JDK version of it
+ * doesn't work properly.
+ *
+ * @param value
+ * the double value
* @return the resulting integer value
*/
- public static /*@pure@*/ int round(double value) {
+ public static/* @pure@ */int round(double value) {
int roundedValue = value > 0
- ? (int)(value + 0.5)
- : -(int)(Math.abs(value) + 0.5);
-
+ ? (int) (value + 0.5)
+ : -(int) (Math.abs(value) + 0.5);
+
return roundedValue;
}
/**
* Rounds a double to the next nearest integer value in a probabilistic
- * fashion (e.g. 0.8 has a 20% chance of being rounded down to 0 and a
- * 80% chance of being rounded up to 1). In the limit, the average of
- * the rounded numbers generated by this procedure should converge to
- * the original double.
- *
- * @param value the double value
- * @param rand the random number generator
+ * fashion (e.g. 0.8 has a 20% chance of being rounded down to 0 and a 80%
+ * chance of being rounded up to 1). In the limit, the average of the rounded
+ * numbers generated by this procedure should converge to the original double.
+ *
+ * @param value
+ * the double value
+ * @param rand
+ * the random number generator
* @return the resulting integer value
*/
public static int probRound(double value, Random rand) {
@@ -1288,52 +1365,54 @@
double lower = Math.floor(value);
double prob = value - lower;
if (rand.nextDouble() < prob) {
- return (int)lower + 1;
+ return (int) lower + 1;
} else {
- return (int)lower;
+ return (int) lower;
}
} else {
double lower = Math.floor(Math.abs(value));
double prob = Math.abs(value) - lower;
if (rand.nextDouble() < prob) {
- return -((int)lower + 1);
+ return -((int) lower + 1);
} else {
- return -(int)lower;
+ return -(int) lower;
}
}
}
/**
* Rounds a double to the given number of decimal places.
- *
- * @param value the double value
- * @param afterDecimalPoint the number of digits after the decimal point
+ *
+ * @param value
+ * the double value
+ * @param afterDecimalPoint
+ * the number of digits after the decimal point
* @return the double rounded to the given precision
*/
- public static /*@pure@*/ double roundDouble(double value,int afterDecimalPoint) {
+ public static/* @pure@ */double roundDouble(double value, int afterDecimalPoint) {
- double mask = Math.pow(10.0, (double)afterDecimalPoint);
+ double mask = Math.pow(10.0, (double) afterDecimalPoint);
- return (double)(Math.round(value * mask)) / mask;
+ return (double) (Math.round(value * mask)) / mask;
}
/**
- * Sorts a given array of integers in ascending order and returns an
- * array of integers with the positions of the elements of the original
- * array in the sorted array. The sort is stable. (Equal elements remain
- * in their original order.)
- *
- * @param array this array is not changed by the method!
- * @return an array of integers with the positions in the sorted
- * array.
+ * Sorts a given array of integers in ascending order and returns an array of
+ * integers with the positions of the elements of the original array in the
+ * sorted array. The sort is stable. (Equal elements remain in their original
+ * order.)
+ *
+ * @param array
+ * this array is not changed by the method!
+ * @return an array of integers with the positions in the sorted array.
*/
- public static /*@pure@*/ int[] sort(int[] array) {
+ public static/* @pure@ */int[] sort(int[] array) {
int[] index = new int[array.length];
int[] newIndex = new int[array.length];
int[] helpIndex;
int numEqual;
-
+
for (int i = 0; i < index.length; i++) {
index[i] = i;
}
@@ -1344,44 +1423,42 @@
while (i < index.length) {
numEqual = 1;
for (int j = i + 1; ((j < index.length)
- && (array[index[i]] == array[index[j]]));
- j++) {
- numEqual++;
+ && (array[index[i]] == array[index[j]])); j++) {
+ numEqual++;
}
if (numEqual > 1) {
- helpIndex = new int[numEqual];
- for (int j = 0; j < numEqual; j++) {
- helpIndex[j] = i + j;
- }
- quickSort(index, helpIndex, 0, numEqual - 1);
- for (int j = 0; j < numEqual; j++) {
- newIndex[i + j] = index[helpIndex[j]];
- }
- i += numEqual;
+ helpIndex = new int[numEqual];
+ for (int j = 0; j < numEqual; j++) {
+ helpIndex[j] = i + j;
+ }
+ quickSort(index, helpIndex, 0, numEqual - 1);
+ for (int j = 0; j < numEqual; j++) {
+ newIndex[i + j] = index[helpIndex[j]];
+ }
+ i += numEqual;
} else {
- newIndex[i] = index[i];
- i++;
+ newIndex[i] = index[i];
+ i++;
}
}
return newIndex;
}
/**
- * Sorts a given array of doubles in ascending order and returns an
- * array of integers with the positions of the elements of the
- * original array in the sorted array. NOTE THESE CHANGES: the sort
- * is no longer stable and it doesn't use safe floating-point
- * comparisons anymore. Occurrences of Double.NaN are treated as
- * Double.MAX_VALUE
- *
- * @param array this array is not changed by the method!
- * @return an array of integers with the positions in the sorted
- * array.
+ * Sorts a given array of doubles in ascending order and returns an array of
+ * integers with the positions of the elements of the original array in the
+ * sorted array. NOTE THESE CHANGES: the sort is no longer stable and it
+ * doesn't use safe floating-point comparisons anymore. Occurrences of
+ * Double.NaN are treated as Double.MAX_VALUE
+ *
+ * @param array
+ * this array is not changed by the method!
+ * @return an array of integers with the positions in the sorted array.
*/
- public static /*@pure@*/ int[] sort(/*@non_null@*/ double[] array) {
+ public static/* @pure@ */int[] sort(/* @non_null@ */double[] array) {
int[] index = new int[array.length];
- array = (double[])array.clone();
+ array = (double[]) array.clone();
for (int i = 0; i < index.length; i++) {
index[i] = i;
if (Double.isNaN(array[i])) {
@@ -1393,51 +1470,50 @@
}
/**
- * Sorts a given array of doubles in ascending order and returns an
- * array of integers with the positions of the elements of the original
- * array in the sorted array. The sort is stable (Equal elements remain
- * in their original order.) Occurrences of Double.NaN are treated as
- * Double.MAX_VALUE
- *
- * @param array this array is not changed by the method!
- * @return an array of integers with the positions in the sorted
- * array.
+ * Sorts a given array of doubles in ascending order and returns an array of
+ * integers with the positions of the elements of the original array in the
+ * sorted array. The sort is stable (Equal elements remain in their original
+ * order.) Occurrences of Double.NaN are treated as Double.MAX_VALUE
+ *
+ * @param array
+ * this array is not changed by the method!
+ * @return an array of integers with the positions in the sorted array.
*/
- public static /*@pure@*/ int[] stableSort(double[] array){
+ public static/* @pure@ */int[] stableSort(double[] array) {
int[] index = new int[array.length];
int[] newIndex = new int[array.length];
int[] helpIndex;
int numEqual;
-
- array = (double[])array.clone();
+
+ array = (double[]) array.clone();
for (int i = 0; i < index.length; i++) {
index[i] = i;
if (Double.isNaN(array[i])) {
array[i] = Double.MAX_VALUE;
}
}
- quickSort(array,index,0,array.length-1);
+ quickSort(array, index, 0, array.length - 1);
// Make sort stable
int i = 0;
while (i < index.length) {
numEqual = 1;
- for (int j = i+1; ((j < index.length) && Utils.eq(array[index[i]],
- array[index[j]])); j++)
- numEqual++;
+ for (int j = i + 1; ((j < index.length) && Utils.eq(array[index[i]],
+ array[index[j]])); j++)
+ numEqual++;
if (numEqual > 1) {
- helpIndex = new int[numEqual];
- for (int j = 0; j < numEqual; j++)
- helpIndex[j] = i+j;
- quickSort(index, helpIndex, 0, numEqual-1);
- for (int j = 0; j < numEqual; j++)
- newIndex[i+j] = index[helpIndex[j]];
- i += numEqual;
+ helpIndex = new int[numEqual];
+ for (int j = 0; j < numEqual; j++)
+ helpIndex[j] = i + j;
+ quickSort(index, helpIndex, 0, numEqual - 1);
+ for (int j = 0; j < numEqual; j++)
+ newIndex[i + j] = index[helpIndex[j]];
+ i += numEqual;
} else {
- newIndex[i] = index[i];
- i++;
+ newIndex[i] = index[i];
+ i++;
}
}
@@ -1446,12 +1522,13 @@
/**
* Computes the variance for an array of doubles.
- *
- * @param vector the array
+ *
+ * @param vector
+ * the array
* @return the variance
*/
- public static /*@pure@*/ double variance(double[] vector) {
-
+ public static/* @pure@ */double variance(double[] vector) {
+
double sum = 0, sumSquared = 0;
if (vector.length <= 1) {
@@ -1461,8 +1538,8 @@
sum += vector[i];
sumSquared += (vector[i] * vector[i]);
}
- double result = (sumSquared - (sum * sum / (double) vector.length)) /
- (double) (vector.length - 1);
+ double result = (sumSquared - (sum * sum / (double) vector.length)) /
+ (double) (vector.length - 1);
// We don't like negative variance
if (result < 0) {
@@ -1474,11 +1551,12 @@
/**
* Computes the sum of the elements of an array of doubles.
- *
- * @param doubles the array of double
+ *
+ * @param doubles
+ * the array of double
* @return the sum of the elements
*/
- public static /*@pure@*/ double sum(double[] doubles) {
+ public static/* @pure@ */double sum(double[] doubles) {
double sum = 0;
@@ -1490,11 +1568,12 @@
/**
* Computes the sum of the elements of an array of integers.
- *
- * @param ints the array of integers
+ *
+ * @param ints
+ * the array of integers
* @return the sum of the elements
*/
- public static /*@pure@*/ int sum(int[] ints) {
+ public static/* @pure@ */int sum(int[] ints) {
int sum = 0;
@@ -1506,12 +1585,13 @@
/**
* Returns c*log2(c) for a given integer value c.
- *
- * @param c an integer value
+ *
+ * @param c
+ * an integer value
* @return c*log2(c) (but is careful to return 0 if c is 0)
*/
- public static /*@pure@*/ double xlogx(int c) {
-
+ public static/* @pure@ */double xlogx(int c) {
+
if (c == 0) {
return 0.0;
}
@@ -1521,16 +1601,20 @@
/**
* Partitions the instances around a pivot. Used by quicksort and
* kthSmallestValue.
- *
- * @param array the array of doubles to be sorted
- * @param index the index into the array of doubles
- * @param l the first index of the subset
- * @param r the last index of the subset
- *
+ *
+ * @param array
+ * the array of doubles to be sorted
+ * @param index
+ * the index into the array of doubles
+ * @param l
+ * the first index of the subset
+ * @param r
+ * the last index of the subset
+ *
* @return the index of the middle element
*/
private static int partition(double[] array, int[] index, int l, int r) {
-
+
double pivot = array[index[(l + r) / 2]];
int help;
@@ -1551,7 +1635,7 @@
}
if ((l == r) && (array[index[r]] > pivot)) {
r--;
- }
+ }
return r;
}
@@ -1559,16 +1643,20 @@
/**
* Partitions the instances around a pivot. Used by quicksort and
* kthSmallestValue.
- *
- * @param array the array of integers to be sorted
- * @param index the index into the array of integers
- * @param l the first index of the subset
- * @param r the last index of the subset
- *
+ *
+ * @param array
+ * the array of integers to be sorted
+ * @param index
+ * the index into the array of integers
+ * @param l
+ * the first index of the subset
+ * @param r
+ * the last index of the subset
+ *
* @return the index of the middle element
*/
private static int partition(int[] array, int[] index, int l, int r) {
-
+
double pivot = array[index[(l + r) / 2]];
int help;
@@ -1589,26 +1677,30 @@
}
if ((l == r) && (array[index[r]] > pivot)) {
r--;
- }
+ }
return r;
}
-
+
/**
- * Implements quicksort according to Manber's "Introduction to
- * Algorithms".
- *
- * @param array the array of doubles to be sorted
- * @param index the index into the array of doubles
- * @param left the first index of the subset to be sorted
- * @param right the last index of the subset to be sorted
+ * Implements quicksort according to Manber's "Introduction to Algorithms".
+ *
+ * @param array
+ * the array of doubles to be sorted
+ * @param index
+ * the index into the array of doubles
+ * @param left
+ * the first index of the subset to be sorted
+ * @param right
+ * the last index of the subset to be sorted
*/
- //@ requires 0 <= first && first <= right && right < array.length;
- //@ requires (\forall int i; 0 <= i && i < index.length; 0 <= index[i] && index[i] < array.length);
- //@ requires array != index;
- // assignable index;
- private static void quickSort(/*@non_null@*/ double[] array, /*@non_null@*/ int[] index,
- int left, int right) {
+ // @ requires 0 <= first && first <= right && right < array.length;
+ // @ requires (\forall int i; 0 <= i && i < index.length; 0 <= index[i] &&
+ // index[i] < array.length);
+ // @ requires array != index;
+ // assignable index;
+ private static void quickSort(/* @non_null@ */double[] array, /* @non_null@ */int[] index,
+ int left, int right) {
if (left < right) {
int middle = partition(array, index, left, right);
@@ -1616,22 +1708,26 @@
quickSort(array, index, middle + 1, right);
}
}
-
+
/**
- * Implements quicksort according to Manber's "Introduction to
- * Algorithms".
- *
- * @param array the array of integers to be sorted
- * @param index the index into the array of integers
- * @param left the first index of the subset to be sorted
- * @param right the last index of the subset to be sorted
+ * Implements quicksort according to Manber's "Introduction to Algorithms".
+ *
+ * @param array
+ * the array of integers to be sorted
+ * @param index
+ * the index into the array of integers
+ * @param left
+ * the first index of the subset to be sorted
+ * @param right
+ * the last index of the subset to be sorted
*/
- //@ requires 0 <= first && first <= right && right < array.length;
- //@ requires (\forall int i; 0 <= i && i < index.length; 0 <= index[i] && index[i] < array.length);
- //@ requires array != index;
- // assignable index;
- private static void quickSort(/*@non_null@*/ int[] array, /*@non_null@*/ int[] index,
- int left, int right) {
+ // @ requires 0 <= first && first <= right && right < array.length;
+ // @ requires (\forall int i; 0 <= i && i < index.length; 0 <= index[i] &&
+ // index[i] < array.length);
+ // @ requires array != index;
+ // assignable index;
+ private static void quickSort(/* @non_null@ */int[] array, /* @non_null@ */int[] index,
+ int left, int right) {
if (left < right) {
int middle = partition(array, index, left, right);
@@ -1639,23 +1735,28 @@
quickSort(array, index, middle + 1, right);
}
}
-
+
/**
- * Implements computation of the kth-smallest element according
- * to Manber's "Introduction to Algorithms".
- *
- * @param array the array of double
- * @param index the index into the array of doubles
- * @param left the first index of the subset
- * @param right the last index of the subset
- * @param k the value of k
- *
+ * Implements computation of the kth-smallest element according to Manber's
+ * "Introduction to Algorithms".
+ *
+ * @param array
+ * the array of double
+ * @param index
+ * the index into the array of doubles
+ * @param left
+ * the first index of the subset
+ * @param right
+ * the last index of the subset
+ * @param k
+ * the value of k
+ *
* @return the index of the kth-smallest element
*/
- //@ requires 0 <= first && first <= right && right < array.length;
- private static int select(/*@non_null@*/ double[] array, /*@non_null@*/ int[] index,
- int left, int right, int k) {
-
+ // @ requires 0 <= first && first <= right && right < array.length;
+ private static int select(/* @non_null@ */double[] array, /* @non_null@ */int[] index,
+ int left, int right, int k) {
+
if (left == right) {
return left;
} else {
@@ -1669,29 +1770,31 @@
}
/**
- * Converts a File's absolute path to a path relative to the user
- * (ie start) directory. Includes an additional workaround for Cygwin, which
- * doesn't like upper case drive letters.
- * @param absolute the File to convert to relative path
+ * Converts a File's absolute path to a path relative to the user (i.e. start)
+ * directory. Includes an additional workaround for Cygwin, which doesn't like
+ * upper case drive letters.
+ *
+ * @param absolute
+ * the File to convert to relative path
* @return a File with a path that is relative to the user's directory
- * @exception Exception if the path cannot be constructed
+ * @exception Exception
+ * if the path cannot be constructed
*/
public static File convertToRelativePath(File absolute) throws Exception {
- File result;
- String fileStr;
-
+ File result;
+ String fileStr;
+
result = null;
-
+
// if we're running windows, it could be Cygwin
if (File.separator.equals("\\")) {
// Cygwin doesn't like upper case drives -> try lower case drive
try {
fileStr = absolute.getPath();
- fileStr = fileStr.substring(0, 1).toLowerCase()
- + fileStr.substring(1);
+ fileStr = fileStr.substring(0, 1).toLowerCase()
+ + fileStr.substring(1);
result = createRelativePath(new File(fileStr));
- }
- catch (Exception e) {
+ } catch (Exception e) {
// no luck with Cygwin workaround, convert it like it is
result = createRelativePath(absolute);
}
@@ -1704,94 +1807,101 @@
}
/**
- * Converts a File's absolute path to a path relative to the user
- * (ie start) directory.
+ * Converts a File's absolute path to a path relative to the user (i.e. start)
+ * directory.
*
- * @param absolute the File to convert to relative path
+ * @param absolute
+ * the File to convert to relative path
* @return a File with a path that is relative to the user's directory
- * @exception Exception if the path cannot be constructed
+ * @exception Exception
+ * if the path cannot be constructed
*/
protected static File createRelativePath(File absolute) throws Exception {
File userDir = new File(System.getProperty("user.dir"));
String userPath = userDir.getAbsolutePath() + File.separator;
- String targetPath = (new File(absolute.getParent())).getPath()
- + File.separator;
+ String targetPath = (new File(absolute.getParent())).getPath()
+ + File.separator;
String fileName = absolute.getName();
StringBuffer relativePath = new StringBuffer();
- // relativePath.append("."+File.separator);
- // System.err.println("User dir "+userPath);
- // System.err.println("Target path "+targetPath);
-
+ // relativePath.append("."+File.separator);
+ // System.err.println("User dir "+userPath);
+ // System.err.println("Target path "+targetPath);
+
// file is in user dir (or subdir)
int subdir = targetPath.indexOf(userPath);
if (subdir == 0) {
if (userPath.length() == targetPath.length()) {
- relativePath.append(fileName);
+ relativePath.append(fileName);
} else {
- int ll = userPath.length();
- relativePath.append(targetPath.substring(ll));
- relativePath.append(fileName);
+ int ll = userPath.length();
+ relativePath.append(targetPath.substring(ll));
+ relativePath.append(fileName);
}
} else {
int sepCount = 0;
String temp = new String(userPath);
while (temp.indexOf(File.separator) != -1) {
- int ind = temp.indexOf(File.separator);
- sepCount++;
- temp = temp.substring(ind+1, temp.length());
+ int ind = temp.indexOf(File.separator);
+ sepCount++;
+ temp = temp.substring(ind + 1, temp.length());
}
-
+
String targetTemp = new String(targetPath);
String userTemp = new String(userPath);
int tcount = 0;
while (targetTemp.indexOf(File.separator) != -1) {
- int ind = targetTemp.indexOf(File.separator);
- int ind2 = userTemp.indexOf(File.separator);
- String tpart = targetTemp.substring(0,ind+1);
- String upart = userTemp.substring(0,ind2+1);
- if (tpart.compareTo(upart) != 0) {
- if (tcount == 0) {
- tcount = -1;
- }
- break;
- }
- tcount++;
- targetTemp = targetTemp.substring(ind+1, targetTemp.length());
- userTemp = userTemp.substring(ind2+1, userTemp.length());
+ int ind = targetTemp.indexOf(File.separator);
+ int ind2 = userTemp.indexOf(File.separator);
+ String tpart = targetTemp.substring(0, ind + 1);
+ String upart = userTemp.substring(0, ind2 + 1);
+ if (tpart.compareTo(upart) != 0) {
+ if (tcount == 0) {
+ tcount = -1;
+ }
+ break;
+ }
+ tcount++;
+ targetTemp = targetTemp.substring(ind + 1, targetTemp.length());
+ userTemp = userTemp.substring(ind2 + 1, userTemp.length());
}
if (tcount == -1) {
- // then target file is probably on another drive (under windows)
- throw new Exception("Can't construct a path to file relative to user "
- +"dir.");
+ // then target file is probably on another drive (under windows)
+ throw new Exception("Can't construct a path to file relative to user "
+ + "dir.");
}
if (targetTemp.indexOf(File.separator) == -1) {
- targetTemp = "";
+ targetTemp = "";
}
for (int i = 0; i < sepCount - tcount; i++) {
- relativePath.append(".."+File.separator);
+ relativePath.append(".." + File.separator);
}
relativePath.append(targetTemp + fileName);
}
- // System.err.println("new path : "+relativePath.toString());
+ // System.err.println("new path : "+relativePath.toString());
return new File(relativePath.toString());
}
-
+
/**
- * Implements computation of the kth-smallest element according
- * to Manber's "Introduction to Algorithms".
- *
- * @param array the array of integers
- * @param index the index into the array of integers
- * @param left the first index of the subset
- * @param right the last index of the subset
- * @param k the value of k
- *
+ * Implements computation of the kth-smallest element according to Manber's
+ * "Introduction to Algorithms".
+ *
+ * @param array
+ * the array of integers
+ * @param index
+ * the index into the array of integers
+ * @param left
+ * the first index of the subset
+ * @param right
+ * the last index of the subset
+ * @param k
+ * the value of k
+ *
* @return the index of the kth-smallest element
*/
- //@ requires 0 <= first && first <= right && right < array.length;
- private static int select(/*@non_null@*/ int[] array, /*@non_null@*/ int[] index,
- int left, int right, int k) {
-
+ // @ requires 0 <= first && first <= right && right < array.length;
+ private static int select(/* @non_null@ */int[] array, /* @non_null@ */int[] index,
+ int left, int right, int k) {
+
if (left == right) {
return left;
} else {
@@ -1803,67 +1913,68 @@
}
}
}
-
-
-
+
/**
* Breaks up the string, if wider than "columns" characters.
- *
- * @param s the string to process
- * @param columns the width in columns
- * @return the processed string
+ *
+ * @param s
+ * the string to process
+ * @param columns
+ * the width in columns
+ * @return the processed string
*/
public static String[] breakUp(String s, int columns) {
- Vector<String> result;
- String line;
- BreakIterator boundary;
- int boundaryStart;
- int boundaryEnd;
- String word;
- String punctuation;
- int i;
- String[] lines;
+ Vector<String> result;
+ String line;
+ BreakIterator boundary;
+ int boundaryStart;
+ int boundaryEnd;
+ String word;
+ String punctuation;
+ int i;
+ String[] lines;
- result = new Vector<String>();
+ result = new Vector<String>();
punctuation = " .,;:!?'\"";
- lines = s.split("\n");
+ lines = s.split("\n");
for (i = 0; i < lines.length; i++) {
- boundary = BreakIterator.getWordInstance();
+ boundary = BreakIterator.getWordInstance();
boundary.setText(lines[i]);
boundaryStart = boundary.first();
- boundaryEnd = boundary.next();
- line = "";
+ boundaryEnd = boundary.next();
+ line = "";
while (boundaryEnd != BreakIterator.DONE) {
- word = lines[i].substring(boundaryStart, boundaryEnd);
- if (line.length() >= columns) {
- if (word.length() == 1) {
- if (punctuation.indexOf(word.charAt(0)) > -1) {
- line += word;
- word = "";
- }
- }
- result.add(line);
- line = "";
- }
- line += word;
- boundaryStart = boundaryEnd;
- boundaryEnd = boundary.next();
+ word = lines[i].substring(boundaryStart, boundaryEnd);
+ if (line.length() >= columns) {
+ if (word.length() == 1) {
+ if (punctuation.indexOf(word.charAt(0)) > -1) {
+ line += word;
+ word = "";
+ }
+ }
+ result.add(line);
+ line = "";
+ }
+ line += word;
+ boundaryStart = boundaryEnd;
+ boundaryEnd = boundary.next();
}
if (line.length() > 0)
- result.add(line);
+ result.add(line);
}
return result.toArray(new String[result.size()]);
}
/**
- * Creates a new instance of an object given it's class name and
- * (optional) arguments to pass to it's setOptions method. If the
- * object implements OptionHandler and the options parameter is
- * non-null, the object will have it's options set. Example use:<p>
- *
+ * Creates a new instance of an object given its class name and (optional)
+ * arguments to pass to its setOptions method. If the object implements
+ * OptionHandler and the options parameter is non-null, the object will have
+ * its options set. Example use:
+ * <p>
+ *
* <code> <pre>
* String classifierName = Utils.getOption('W', options);
* Classifier c = (Classifier)Utils.forName(Classifier.class,
@@ -1871,21 +1982,25 @@
* options);
* setClassifier(c);
* </pre></code>
- *
- * @param classType the class that the instantiated object should
- * be assignable to -- an exception is thrown if this is not the case
- * @param className the fully qualified class name of the object
- * @param options an array of options suitable for passing to setOptions. May
- * be null. Any options accepted by the object will be removed from the
- * array.
+ *
+ * @param classType
+ * the class that the instantiated object should be assignable to --
+ * an exception is thrown if this is not the case
+ * @param className
+ * the fully qualified class name of the object
+ * @param options
+ * an array of options suitable for passing to setOptions. May be
+ * null. Any options accepted by the object will be removed from the
+ * array.
* @return the newly created object, ready for use.
- * @exception Exception if the class name is invalid, or if the
- * class is not assignable to the desired class type, or the options
- * supplied are not acceptable to the object
+ * @exception Exception
+ * if the class name is invalid, or if the class is not
+ * assignable to the desired class type, or the options supplied
+ * are not acceptable to the object
*/
public static Object forName(Class<?> classType,
- String className,
- String[] options) throws Exception {
+ String className,
+ String[] options) throws Exception {
Class<?> c = null;
try {
@@ -1895,17 +2010,15 @@
}
if (!classType.isAssignableFrom(c)) {
throw new Exception(classType.getName() + " is not assignable from "
- + className);
+ + className);
}
Object o = c.newInstance();
- /*if ((o instanceof OptionHandler)
- && (options != null)) {
- ((OptionHandler)o).setOptions(options);
- Utils.checkForRemainingOptions(options);
- }*/
+ /*
+ * if ((o instanceof OptionHandler) && (options != null)) {
+ * ((OptionHandler)o).setOptions(options);
+ * Utils.checkForRemainingOptions(options); }
+ */
return o;
}
}
-
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningCurve.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningCurve.java
index e79dde8..907371c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningCurve.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningCurve.java
@@ -30,103 +30,103 @@
/**
* Class that stores and keeps the history of evaluation measurements.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class LearningCurve extends AbstractMOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected List<String> measurementNames = new ArrayList<String>();
+ protected List<String> measurementNames = new ArrayList<String>();
- protected List<double[]> measurementValues = new ArrayList<double[]>();
+ protected List<double[]> measurementValues = new ArrayList<double[]>();
- public LearningCurve(String orderingMeasurementName) {
- this.measurementNames.add(orderingMeasurementName);
+ public LearningCurve(String orderingMeasurementName) {
+ this.measurementNames.add(orderingMeasurementName);
+ }
+
+ public String getOrderingMeasurementName() {
+ return this.measurementNames.get(0);
+ }
+
+ public void insertEntry(LearningEvaluation learningEvaluation) {
+ Measurement[] measurements = learningEvaluation.getMeasurements();
+ Measurement orderMeasurement = Measurement.getMeasurementNamed(
+ getOrderingMeasurementName(), measurements);
+ if (orderMeasurement == null) {
+ throw new IllegalArgumentException();
}
-
- public String getOrderingMeasurementName() {
- return this.measurementNames.get(0);
+ DoubleVector entryVals = new DoubleVector();
+ for (Measurement measurement : measurements) {
+ entryVals.setValue(addMeasurementName(measurement.getName()),
+ measurement.getValue());
}
-
- public void insertEntry(LearningEvaluation learningEvaluation) {
- Measurement[] measurements = learningEvaluation.getMeasurements();
- Measurement orderMeasurement = Measurement.getMeasurementNamed(
- getOrderingMeasurementName(), measurements);
- if (orderMeasurement == null) {
- throw new IllegalArgumentException();
- }
- DoubleVector entryVals = new DoubleVector();
- for (Measurement measurement : measurements) {
- entryVals.setValue(addMeasurementName(measurement.getName()),
- measurement.getValue());
- }
- double orderVal = orderMeasurement.getValue();
- int index = 0;
- while ((index < this.measurementValues.size())
- && (orderVal > this.measurementValues.get(index)[0])) {
- index++;
- }
- this.measurementValues.add(index, entryVals.getArrayRef());
+ double orderVal = orderMeasurement.getValue();
+ int index = 0;
+ while ((index < this.measurementValues.size())
+ && (orderVal > this.measurementValues.get(index)[0])) {
+ index++;
}
+ this.measurementValues.add(index, entryVals.getArrayRef());
+ }
- public int numEntries() {
- return this.measurementValues.size();
- }
+ public int numEntries() {
+ return this.measurementValues.size();
+ }
- protected int addMeasurementName(String name) {
- int index = this.measurementNames.indexOf(name);
- if (index < 0) {
- index = this.measurementNames.size();
- this.measurementNames.add(name);
- }
- return index;
+ protected int addMeasurementName(String name) {
+ int index = this.measurementNames.indexOf(name);
+ if (index < 0) {
+ index = this.measurementNames.size();
+ this.measurementNames.add(name);
}
+ return index;
+ }
- public String headerToString() {
- StringBuilder sb = new StringBuilder();
- boolean first = true;
- for (String name : this.measurementNames) {
- if (!first) {
- sb.append(',');
- } else {
- first = false;
- }
- sb.append(name);
- }
- return sb.toString();
+ public String headerToString() {
+ StringBuilder sb = new StringBuilder();
+ boolean first = true;
+ for (String name : this.measurementNames) {
+ if (!first) {
+ sb.append(',');
+ } else {
+ first = false;
+ }
+ sb.append(name);
}
+ return sb.toString();
+ }
- public String entryToString(int entryIndex) {
- StringBuilder sb = new StringBuilder();
- double[] vals = this.measurementValues.get(entryIndex);
- for (int i = 0; i < this.measurementNames.size(); i++) {
- if (i > 0) {
- sb.append(',');
- }
- if ((i >= vals.length) || Double.isNaN(vals[i])) {
- sb.append('?');
- } else {
- sb.append(Double.toString(vals[i]));
- }
- }
- return sb.toString();
+ public String entryToString(int entryIndex) {
+ StringBuilder sb = new StringBuilder();
+ double[] vals = this.measurementValues.get(entryIndex);
+ for (int i = 0; i < this.measurementNames.size(); i++) {
+ if (i > 0) {
+ sb.append(',');
+ }
+ if ((i >= vals.length) || Double.isNaN(vals[i])) {
+ sb.append('?');
+ } else {
+ sb.append(Double.toString(vals[i]));
+ }
}
+ return sb.toString();
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- sb.append(headerToString());
- for (int i = 0; i < numEntries(); i++) {
- StringUtils.appendNewlineIndented(sb, indent, entryToString(i));
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ sb.append(headerToString());
+ for (int i = 0; i < numEntries(); i++) {
+ StringUtils.appendNewlineIndented(sb, indent, entryToString(i));
}
+ }
- public double getMeasurement(int entryIndex, int measurementIndex) {
- return this.measurementValues.get(entryIndex)[measurementIndex];
- }
+ public double getMeasurement(int entryIndex, int measurementIndex) {
+ return this.measurementValues.get(entryIndex)[measurementIndex];
+ }
- public String getMeasurementName(int measurementIndex) {
- return this.measurementNames.get(measurementIndex);
- }
+ public String getMeasurementName(int measurementIndex) {
+ return this.measurementNames.get(measurementIndex);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningEvaluation.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningEvaluation.java
index b0aee36..67b978c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningEvaluation.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningEvaluation.java
@@ -29,35 +29,35 @@
/**
* Class that stores an array of evaluation measurements.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class LearningEvaluation extends AbstractMOAObject {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- protected Measurement[] measurements;
+ protected Measurement[] measurements;
- public LearningEvaluation(Measurement[] measurements) {
- this.measurements = measurements.clone();
- }
+ public LearningEvaluation(Measurement[] measurements) {
+ this.measurements = measurements.clone();
+ }
- public LearningEvaluation(Measurement[] evaluationMeasurements,
- LearningPerformanceEvaluator cpe, Learner model) {
- List<Measurement> measurementList = new LinkedList<Measurement>();
- measurementList.addAll(Arrays.asList(evaluationMeasurements));
- measurementList.addAll(Arrays.asList(cpe.getPerformanceMeasurements()));
- measurementList.addAll(Arrays.asList(model.getModelMeasurements()));
- this.measurements = measurementList.toArray(new Measurement[measurementList.size()]);
- }
+ public LearningEvaluation(Measurement[] evaluationMeasurements,
+ LearningPerformanceEvaluator cpe, Learner model) {
+ List<Measurement> measurementList = new LinkedList<Measurement>();
+ measurementList.addAll(Arrays.asList(evaluationMeasurements));
+ measurementList.addAll(Arrays.asList(cpe.getPerformanceMeasurements()));
+ measurementList.addAll(Arrays.asList(model.getModelMeasurements()));
+ this.measurements = measurementList.toArray(new Measurement[measurementList.size()]);
+ }
- public Measurement[] getMeasurements() {
- return this.measurements.clone();
- }
+ public Measurement[] getMeasurements() {
+ return this.measurements.clone();
+ }
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- Measurement.getMeasurementsDescription(this.measurements, sb, indent);
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ Measurement.getMeasurementsDescription(this.measurements, sb, indent);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningPerformanceEvaluator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningPerformanceEvaluator.java
index 86b571d..b236066 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningPerformanceEvaluator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/LearningPerformanceEvaluator.java
@@ -20,42 +20,43 @@
* #L%
*/
-
import com.yahoo.labs.samoa.moa.MOAObject;
import com.yahoo.labs.samoa.moa.core.Example;
import com.yahoo.labs.samoa.moa.core.Measurement;
/**
- * Interface implemented by learner evaluators to monitor
- * the results of the learning process.
- *
+ * Interface implemented by learner evaluators to monitor the results of the
+ * learning process.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface LearningPerformanceEvaluator<E extends Example> extends MOAObject {
- /**
- * Resets this evaluator. It must be similar to
- * starting a new evaluator from scratch.
- *
- */
- public void reset();
+ /**
+ * Resets this evaluator. It must be similar to starting a new evaluator from
+ * scratch.
+ *
+ */
+ public void reset();
- /**
- * Adds a learning result to this evaluator.
- *
- * @param example the example to be classified
- * @param classVotes an array containing the estimated membership
- * probabilities of the test instance in each class
- * @return an array of measurements monitored in this evaluator
- */
- public void addResult(E example, double[] classVotes);
+ /**
+ * Adds a learning result to this evaluator.
+ *
+ * @param example
+ * the example to be classified
+ * @param classVotes
+ * an array containing the estimated membership probabilities of the
+ * test instance in each class
+ * @return an array of measurements monitored in this evaluator
+ */
+ public void addResult(E example, double[] classVotes);
- /**
- * Gets the current measurements monitored by this evaluator.
- *
- * @return an array of measurements monitored by this evaluator
- */
- public Measurement[] getPerformanceMeasurements();
+ /**
+ * Gets the current measurements monitored by this evaluator.
+ *
+ * @return an array of measurements monitored by this evaluator
+ */
+ public Measurement[] getPerformanceMeasurements();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MeasureCollection.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MeasureCollection.java
index 898aca3..9fe1fc3 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MeasureCollection.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MeasureCollection.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.evaluation;
/*
@@ -27,233 +26,236 @@
import com.yahoo.labs.samoa.moa.cluster.Clustering;
import com.yahoo.labs.samoa.moa.core.DataPoint;
-public abstract class MeasureCollection extends AbstractMOAObject{
- private String[] names;
- private ArrayList<Double>[] values;
- private ArrayList<Double>[] sortedValues;
- private ArrayList<String> events;
-
- private double[] minValue;
- private double[] maxValue;
- private double[] sumValues;
- private boolean[] enabled;
- private boolean[] corrupted;
- private double time;
- private boolean debug = true;
- private MembershipMatrix mm = null;
+public abstract class MeasureCollection extends AbstractMOAObject {
+ private String[] names;
+ private ArrayList<Double>[] values;
+ private ArrayList<Double>[] sortedValues;
+ private ArrayList<String> events;
- private HashMap<String, Integer> map;
+ private double[] minValue;
+ private double[] maxValue;
+ private double[] sumValues;
+ private boolean[] enabled;
+ private boolean[] corrupted;
+ private double time;
+ private boolean debug = true;
+ private MembershipMatrix mm = null;
- private int numMeasures = 0;
-
-
+ private HashMap<String, Integer> map;
+ private int numMeasures = 0;
- public MeasureCollection() {
- names = getNames();
- numMeasures = names.length;
- map = new HashMap<String, Integer>(numMeasures);
- for (int i = 0; i < names.length; i++) {
- map.put(names[i],i);
- }
- values = (ArrayList<Double>[]) new ArrayList[numMeasures];
- sortedValues = (ArrayList<Double>[]) new ArrayList[numMeasures];
- maxValue = new double[numMeasures];
- minValue = new double[numMeasures];
- sumValues = new double[numMeasures];
- corrupted = new boolean[numMeasures];
- enabled = getDefaultEnabled();
- time = 0;
- events = new ArrayList<String>();
+ public MeasureCollection() {
+ names = getNames();
+ numMeasures = names.length;
+ map = new HashMap<String, Integer>(numMeasures);
+ for (int i = 0; i < names.length; i++) {
+ map.put(names[i], i);
+ }
+ values = (ArrayList<Double>[]) new ArrayList[numMeasures];
+ sortedValues = (ArrayList<Double>[]) new ArrayList[numMeasures];
+ maxValue = new double[numMeasures];
+ minValue = new double[numMeasures];
+ sumValues = new double[numMeasures];
+ corrupted = new boolean[numMeasures];
+ enabled = getDefaultEnabled();
+ time = 0;
+ events = new ArrayList<String>();
- for (int i = 0; i < numMeasures; i++) {
- values[i] = new ArrayList<Double>();
- sortedValues[i] = new ArrayList<Double>();
- maxValue[i] = Double.MIN_VALUE;
- minValue[i] = Double.MAX_VALUE;
- corrupted[i] = false;
- sumValues[i] = 0.0;
- }
-
+ for (int i = 0; i < numMeasures; i++) {
+ values[i] = new ArrayList<Double>();
+ sortedValues[i] = new ArrayList<Double>();
+ maxValue[i] = Double.MIN_VALUE;
+ minValue[i] = Double.MAX_VALUE;
+ corrupted[i] = false;
+ sumValues[i] = 0.0;
}
- protected abstract String[] getNames();
+ }
- public void addValue(int index, double value){
- if(Double.isNaN(value)){
- if(debug)
- System.out.println("NaN for "+names[index]);
- corrupted[index] = true;
- }
- if(value < 0){
- if(debug)
- System.out.println("Negative value for "+names[index]);
- }
+ protected abstract String[] getNames();
- values[index].add(value);
- sumValues[index]+=value;
- if(value < minValue[index]) minValue[index] = value;
- if(value > maxValue[index]) maxValue[index] = value;
- }
-
- protected void addValue(String name, double value){
- if(map.containsKey(name)){
- addValue(map.get(name),value);
- }
- else{
- System.out.println(name+" is not a valid measure key, no value added");
- }
- }
-
- //add an empty entry e.g. if evaluation crashed internally
- public void addEmptyValue(int index){
- values[index].add(Double.NaN);
- corrupted[index] = true;
- }
-
- public int getNumMeasures(){
- return numMeasures;
- }
-
- public String getName(int index){
- return names[index];
- }
-
- public double getMaxValue(int index){
- return maxValue[index];
- }
-
- public double getMinValue(int index){
- return minValue[index];
- }
-
- public double getLastValue(int index){
- if(values[index].size()<1) return Double.NaN;
- return values[index].get(values[index].size()-1);
- }
-
- public double getMean(int index){
- if(corrupted[index] || values[index].size()<1)
- return Double.NaN;
-
- return sumValues[index]/values[index].size();
- }
-
- private void updateSortedValues(int index){
- //naive implementation of insertion sort
- for (int i = sortedValues[index].size(); i < values[index].size(); i++) {
- double v = values[index].get(i);
- int insertIndex = 0;
- while(!sortedValues[index].isEmpty() && insertIndex < sortedValues[index].size() && v > sortedValues[index].get(insertIndex))
- insertIndex++;
- sortedValues[index].add(insertIndex,v);
- }
-// for (int i = 0; i < sortedValues[index].size(); i++) {
-// System.out.print(sortedValues[index].get(i)+" ");
-// }
-// System.out.println();
- }
-
- public void clean(int index){
- sortedValues[index].clear();
- }
-
- public double getMedian(int index){
- updateSortedValues(index);
- int size = sortedValues[index].size();
-
- if(size > 0){
- if(size%2 == 1)
- return sortedValues[index].get((int)(size/2));
- else
- return (sortedValues[index].get((size-1)/2)+sortedValues[index].get((size-1)/2+1))/2.0;
- }
- return Double.NaN;
+ public void addValue(int index, double value) {
+ if (Double.isNaN(value)) {
+ if (debug)
+ System.out.println("NaN for " + names[index]);
+ corrupted[index] = true;
+ }
+ if (value < 0) {
+ if (debug)
+ System.out.println("Negative value for " + names[index]);
}
- public double getLowerQuartile(int index){
- updateSortedValues(index);
- int size = sortedValues[index].size();
- if(size > 11){
- return sortedValues[index].get(Math.round(size*0.25f));
- }
- return Double.NaN;
- }
+ values[index].add(value);
+ sumValues[index] += value;
+ if (value < minValue[index])
+ minValue[index] = value;
+ if (value > maxValue[index])
+ maxValue[index] = value;
+ }
- public double getUpperQuartile(int index){
- updateSortedValues(index);
- int size = sortedValues[index].size();
- if(size > 11){
- return sortedValues[index].get(Math.round(size*0.75f-1));
- }
- return Double.NaN;
- }
-
-
- public int getNumberOfValues(int index){
- return values[index].size();
- }
-
- public double getValue(int index, int i){
- if(i>=values[index].size()) return Double.NaN;
- return values[index].get(i);
- }
-
- public ArrayList<Double> getAllValues(int index){
- return values[index];
- }
-
- public void setEnabled(int index, boolean value){
- enabled[index] = value;
- }
-
- public boolean isEnabled(int index){
- return enabled[index];
- }
-
- public double getMeanRunningTime(){
- if(values[0].size()!=0)
- return (time/10e5/values[0].size());
- else
- return 0;
- }
-
- protected boolean[] getDefaultEnabled(){
- boolean[] defaults = new boolean[numMeasures];
- for (int i = 0; i < defaults.length; i++) {
- defaults[i] = true;
- }
- return defaults;
- }
-
- protected abstract void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) throws Exception;
-
- /*
- * Evaluate Clustering
- *
- * return Time in milliseconds
- */
- public double evaluateClusteringPerformance(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) throws Exception{
- long start = System.nanoTime();
- evaluateClustering(clustering, trueClustering, points);
- long duration = System.nanoTime()-start;
- time+=duration;
- duration/=10e5;
- return duration;
- }
-
- public void getDescription(StringBuilder sb, int indent) {
-
+ protected void addValue(String name, double value) {
+ if (map.containsKey(name)) {
+ addValue(map.get(name), value);
}
+ else {
+ System.out.println(name + " is not a valid measure key, no value added");
+ }
+ }
- public void addEventType(String type){
- events.add(type);
- }
- public String getEventType(int index){
- if(index < events.size())
- return events.get(index);
- else
- return null;
- }
+ // add an empty entry e.g. if evaluation crashed internally
+ public void addEmptyValue(int index) {
+ values[index].add(Double.NaN);
+ corrupted[index] = true;
+ }
+
+ public int getNumMeasures() {
+ return numMeasures;
+ }
+
+ public String getName(int index) {
+ return names[index];
+ }
+
+ public double getMaxValue(int index) {
+ return maxValue[index];
+ }
+
+ public double getMinValue(int index) {
+ return minValue[index];
+ }
+
+ public double getLastValue(int index) {
+ if (values[index].size() < 1)
+ return Double.NaN;
+ return values[index].get(values[index].size() - 1);
+ }
+
+ public double getMean(int index) {
+ if (corrupted[index] || values[index].size() < 1)
+ return Double.NaN;
+
+ return sumValues[index] / values[index].size();
+ }
+
+ private void updateSortedValues(int index) {
+ // naive implementation of insertion sort
+ for (int i = sortedValues[index].size(); i < values[index].size(); i++) {
+ double v = values[index].get(i);
+ int insertIndex = 0;
+ while (!sortedValues[index].isEmpty() && insertIndex < sortedValues[index].size()
+ && v > sortedValues[index].get(insertIndex))
+ insertIndex++;
+ sortedValues[index].add(insertIndex, v);
+ }
+ // for (int i = 0; i < sortedValues[index].size(); i++) {
+ // System.out.print(sortedValues[index].get(i)+" ");
+ // }
+ // System.out.println();
+ }
+
+ public void clean(int index) {
+ sortedValues[index].clear();
+ }
+
+ public double getMedian(int index) {
+ updateSortedValues(index);
+ int size = sortedValues[index].size();
+
+ if (size > 0) {
+ if (size % 2 == 1)
+ return sortedValues[index].get((int) (size / 2));
+ else
+ return (sortedValues[index].get((size - 1) / 2) + sortedValues[index].get((size - 1) / 2 + 1)) / 2.0;
+ }
+ return Double.NaN;
+ }
+
+ public double getLowerQuartile(int index) {
+ updateSortedValues(index);
+ int size = sortedValues[index].size();
+ if (size > 11) {
+ return sortedValues[index].get(Math.round(size * 0.25f));
+ }
+ return Double.NaN;
+ }
+
+ public double getUpperQuartile(int index) {
+ updateSortedValues(index);
+ int size = sortedValues[index].size();
+ if (size > 11) {
+ return sortedValues[index].get(Math.round(size * 0.75f - 1));
+ }
+ return Double.NaN;
+ }
+
+ public int getNumberOfValues(int index) {
+ return values[index].size();
+ }
+
+ public double getValue(int index, int i) {
+ if (i >= values[index].size())
+ return Double.NaN;
+ return values[index].get(i);
+ }
+
+ public ArrayList<Double> getAllValues(int index) {
+ return values[index];
+ }
+
+ public void setEnabled(int index, boolean value) {
+ enabled[index] = value;
+ }
+
+ public boolean isEnabled(int index) {
+ return enabled[index];
+ }
+
+ public double getMeanRunningTime() {
+ if (values[0].size() != 0)
+ return (time / 10e5 / values[0].size());
+ else
+ return 0;
+ }
+
+ protected boolean[] getDefaultEnabled() {
+ boolean[] defaults = new boolean[numMeasures];
+ for (int i = 0; i < defaults.length; i++) {
+ defaults[i] = true;
+ }
+ return defaults;
+ }
+
+ protected abstract void evaluateClustering(Clustering clustering, Clustering trueClustering,
+ ArrayList<DataPoint> points) throws Exception;
+
+ /*
+ * Evaluate Clustering
+ *
+ * return Time in milliseconds
+ */
+ public double evaluateClusteringPerformance(Clustering clustering, Clustering trueClustering,
+ ArrayList<DataPoint> points) throws Exception {
+ long start = System.nanoTime();
+ evaluateClustering(clustering, trueClustering, points);
+ long duration = System.nanoTime() - start;
+ time += duration;
+ duration /= 10e5;
+ return duration;
+ }
+
+ public void getDescription(StringBuilder sb, int indent) {
+
+ }
+
+ public void addEventType(String type) {
+ events.add(type);
+ }
+
+ public String getEventType(int index) {
+ if (index < events.size())
+ return events.get(index);
+ else
+ return null;
+ }
}
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MembershipMatrix.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MembershipMatrix.java
index 14227ad..c14c145 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MembershipMatrix.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/evaluation/MembershipMatrix.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.evaluation;
/*
@@ -26,128 +25,124 @@
import com.yahoo.labs.samoa.moa.cluster.Clustering;
import com.yahoo.labs.samoa.moa.core.DataPoint;
-
public class MembershipMatrix {
- HashMap<Integer, Integer> classmap;
- int cluster_class_weights[][];
- int cluster_sums[];
- int class_sums[];
- int total_entries;
- int class_distribution[];
- int total_class_entries;
- int initalBuildTimestamp = -1;
+ HashMap<Integer, Integer> classmap;
+ int cluster_class_weights[][];
+ int cluster_sums[];
+ int class_sums[];
+ int total_entries;
+ int class_distribution[];
+ int total_class_entries;
+ int initalBuildTimestamp = -1;
- public MembershipMatrix(Clustering foundClustering, ArrayList<DataPoint> points) {
- classmap = Clustering.classValues(points);
-// int lastID = classmap.size()-1;
-// classmap.put(-1, lastID);
- int numClasses = classmap.size();
- int numCluster = foundClustering.size()+1;
+ public MembershipMatrix(Clustering foundClustering, ArrayList<DataPoint> points) {
+ classmap = Clustering.classValues(points);
+ // int lastID = classmap.size()-1;
+ // classmap.put(-1, lastID);
+ int numClasses = classmap.size();
+ int numCluster = foundClustering.size() + 1;
- cluster_class_weights = new int[numCluster][numClasses];
- class_distribution = new int[numClasses];
- cluster_sums = new int[numCluster];
- class_sums = new int[numClasses];
- total_entries = 0;
- total_class_entries = points.size();
- for (int p = 0; p < points.size(); p++) {
- int worklabel = classmap.get((int)points.get(p).classValue());
- //real class distribution
- class_distribution[worklabel]++;
- boolean covered = false;
- for (int c = 0; c < numCluster-1; c++) {
- double prob = foundClustering.get(c).getInclusionProbability(points.get(p));
- if(prob >= 1){
- cluster_class_weights[c][worklabel]++;
- class_sums[worklabel]++;
- cluster_sums[c]++;
- total_entries++;
- covered = true;
- }
- }
- if(!covered){
- cluster_class_weights[numCluster-1][worklabel]++;
- class_sums[worklabel]++;
- cluster_sums[numCluster-1]++;
- total_entries++;
- }
-
+ cluster_class_weights = new int[numCluster][numClasses];
+ class_distribution = new int[numClasses];
+ cluster_sums = new int[numCluster];
+ class_sums = new int[numClasses];
+ total_entries = 0;
+ total_class_entries = points.size();
+ for (int p = 0; p < points.size(); p++) {
+ int worklabel = classmap.get((int) points.get(p).classValue());
+ // real class distribution
+ class_distribution[worklabel]++;
+ boolean covered = false;
+ for (int c = 0; c < numCluster - 1; c++) {
+ double prob = foundClustering.get(c).getInclusionProbability(points.get(p));
+ if (prob >= 1) {
+ cluster_class_weights[c][worklabel]++;
+ class_sums[worklabel]++;
+ cluster_sums[c]++;
+ total_entries++;
+ covered = true;
}
-
- initalBuildTimestamp = points.get(0).getTimestamp();
+ }
+ if (!covered) {
+ cluster_class_weights[numCluster - 1][worklabel]++;
+ class_sums[worklabel]++;
+ cluster_sums[numCluster - 1]++;
+ total_entries++;
+ }
+
}
- public int getClusterClassWeight(int i, int j){
- return cluster_class_weights[i][j];
+ initalBuildTimestamp = points.get(0).getTimestamp();
+ }
+
+ public int getClusterClassWeight(int i, int j) {
+ return cluster_class_weights[i][j];
+ }
+
+ public int getClusterSum(int i) {
+ return cluster_sums[i];
+ }
+
+ public int getClassSum(int j) {
+ return class_sums[j];
+ }
+
+ public int getClassDistribution(int j) {
+ return class_distribution[j];
+ }
+
+ public int getClusterClassWeightByLabel(int cluster, int classLabel) {
+ return cluster_class_weights[cluster][classmap.get(classLabel)];
+ }
+
+ public int getClassSumByLabel(int classLabel) {
+ return class_sums[classmap.get(classLabel)];
+ }
+
+ public int getClassDistributionByLabel(int classLabel) {
+ return class_distribution[classmap.get(classLabel)];
+ }
+
+ public int getTotalEntries() {
+ return total_entries;
+ }
+
+ public int getNumClasses() {
+ return classmap.size();
+ }
+
+ public boolean hasNoiseClass() {
+ return classmap.containsKey(-1);
+ }
+
+ @Override
+ public String toString() {
+ StringBuffer sb = new StringBuffer();
+ sb.append("Membership Matrix\n");
+ for (int i = 0; i < cluster_class_weights.length; i++) {
+ for (int j = 0; j < cluster_class_weights[i].length; j++) {
+ sb.append(cluster_class_weights[i][j] + "\t ");
+ }
+ sb.append("| " + cluster_sums[i] + "\n");
}
-
- public int getClusterSum(int i){
- return cluster_sums[i];
+ // sb.append("-----------\n");
+ for (int i = 0; i < class_sums.length; i++) {
+ sb.append(class_sums[i] + "\t ");
}
+ sb.append("| " + total_entries + "\n");
- public int getClassSum(int j){
- return class_sums[j];
+ sb.append("Real class distribution \n");
+ for (int i = 0; i < class_distribution.length; i++) {
+ sb.append(class_distribution[i] + "\t ");
}
+ sb.append("| " + total_class_entries + "\n");
- public int getClassDistribution(int j){
- return class_distribution[j];
- }
+ return sb.toString();
+ }
- public int getClusterClassWeightByLabel(int cluster, int classLabel){
- return cluster_class_weights[cluster][classmap.get(classLabel)];
- }
-
- public int getClassSumByLabel(int classLabel){
- return class_sums[classmap.get(classLabel)];
- }
-
- public int getClassDistributionByLabel(int classLabel){
- return class_distribution[classmap.get(classLabel)];
- }
-
- public int getTotalEntries(){
- return total_entries;
- }
-
- public int getNumClasses(){
- return classmap.size();
- }
-
- public boolean hasNoiseClass(){
- return classmap.containsKey(-1);
- }
-
- @Override
- public String toString() {
- StringBuffer sb = new StringBuffer();
- sb.append("Membership Matrix\n");
- for (int i = 0; i < cluster_class_weights.length; i++) {
- for (int j = 0; j < cluster_class_weights[i].length; j++) {
- sb.append(cluster_class_weights[i][j]+"\t ");
- }
- sb.append("| "+cluster_sums[i]+"\n");
- }
- //sb.append("-----------\n");
- for (int i = 0; i < class_sums.length; i++) {
- sb.append(class_sums[i]+"\t ");
- }
- sb.append("| "+total_entries+"\n");
-
-
- sb.append("Real class distribution \n");
- for (int i = 0; i < class_distribution.length; i++) {
- sb.append(class_distribution[i]+"\t ");
- }
- sb.append("| "+total_class_entries+"\n");
-
- return sb.toString();
- }
-
-
- public int getInitalBuildTimestamp(){
- return initalBuildTimestamp;
- }
-
+ public int getInitalBuildTimestamp() {
+ return initalBuildTimestamp;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/learners/Learner.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/learners/Learner.java
index 15de94b..be63e30 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/learners/Learner.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/learners/Learner.java
@@ -27,110 +27,109 @@
import com.yahoo.labs.samoa.moa.options.OptionHandler;
/**
- * Learner interface for incremental learning models.
- *
+ * Learner interface for incremental learning models.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface Learner<E extends Example> extends MOAObject, OptionHandler {
+ /**
+ * Gets whether this learner needs a random seed. Examples of methods that
+ * needs a random seed are bagging and boosting.
+ *
+ * @return true if the learner needs a random seed.
+ */
+ public boolean isRandomizable();
- /**
- * Gets whether this learner needs a random seed.
- * Examples of methods that needs a random seed are bagging and boosting.
- *
- * @return true if the learner needs a random seed.
- */
- public boolean isRandomizable();
+ /**
+ * Sets the seed for random number generation.
+ *
+ * @param s
+ * the seed
+ */
+ public void setRandomSeed(int s);
- /**
- * Sets the seed for random number generation.
- *
- * @param s the seed
- */
- public void setRandomSeed(int s);
+ /**
+ * Gets whether training has started.
+ *
+ * @return true if training has started
+ */
+ public boolean trainingHasStarted();
- /**
- * Gets whether training has started.
- *
- * @return true if training has started
- */
- public boolean trainingHasStarted();
+ /**
+ * Gets the sum of the weights of the instances that have been used by this
+ * learner during the training in <code>trainOnInstance</code>
+ *
+ * @return the weight of the instances that have been used training
+ */
+ public double trainingWeightSeenByModel();
- /**
- * Gets the sum of the weights of the instances that have been used
- * by this learner during the training in <code>trainOnInstance</code>
- *
- * @return the weight of the instances that have been used training
- */
- public double trainingWeightSeenByModel();
+ /**
+ * Resets this learner. It must be similar to starting a new learner from
+ * scratch.
+ *
+ */
+ public void resetLearning();
- /**
- * Resets this learner. It must be similar to
- * starting a new learner from scratch.
- *
- */
- public void resetLearning();
+ /**
+ * Trains this learner incrementally using the given example.
+ *
+ * @param example
+ * the instance to be used for training
+ */
+ public void trainOnInstance(E example);
- /**
- * Trains this learner incrementally using the given example.
- *
- * @param inst the instance to be used for training
- */
- public void trainOnInstance(E example);
+ /**
+ * Predicts the class memberships for a given instance. If an instance is
+ * unclassified, the returned array elements must be all zero.
+ *
+ * @param example
+ * the instance to be classified
+ * @return an array containing the estimated membership probabilities of the
+ * test instance in each class
+ */
+ public double[] getVotesForInstance(E example);
- /**
- * Predicts the class memberships for a given instance. If
- * an instance is unclassified, the returned array elements
- * must be all zero.
- *
- * @param inst the instance to be classified
- * @return an array containing the estimated membership
- * probabilities of the test instance in each class
- */
- public double[] getVotesForInstance(E example);
+ /**
+ * Gets the current measurements of this learner.
+ *
+ * @return an array of measurements to be used in evaluation tasks
+ */
+ public Measurement[] getModelMeasurements();
- /**
- * Gets the current measurements of this learner.
- *
- * @return an array of measurements to be used in evaluation tasks
- */
- public Measurement[] getModelMeasurements();
+ /**
+ * Gets the learners of this ensemble. Returns null if this learner is a
+ * single learner.
+ *
+ * @return an array of the learners of the ensemble
+ */
+ public Learner[] getSublearners();
- /**
- * Gets the learners of this ensemble.
- * Returns null if this learner is a single learner.
- *
- * @return an array of the learners of the ensemble
- */
- public Learner[] getSublearners();
+ /**
+ * Gets the model of this learner.
+ *
+ * @return the copy of this learner
+ */
+ public MOAObject getModel();
- /**
- * Gets the model if this learner.
- *
- * @return the copy of this learner
- */
- public MOAObject getModel();
-
- /**
- * Sets the reference to the header of the data stream.
- * The header of the data stream is extended from WEKA <code>Instances</code>.
- * This header is needed to know the number of classes and attributes
- *
- * @param ih the reference to the data stream header
- */
- public void setModelContext(InstancesHeader ih);
-
- /**
- * Gets the reference to the header of the data stream.
- * The header of the data stream is extended from WEKA <code>Instances</code>.
- * This header is needed to know the number of classes and attributes
- *
- * @return the reference to the data stream header
- */
- public InstancesHeader getModelContext();
-
+ /**
+ * Sets the reference to the header of the data stream. The header of the data
+ * stream is extended from WEKA <code>Instances</code>. This header is needed
+ * to know the number of classes and attributes
+ *
+ * @param ih
+ * the reference to the data stream header
+ */
+ public void setModelContext(InstancesHeader ih);
+
+ /**
+ * Gets the reference to the header of the data stream. The header of the data
+ * stream is extended from WEKA <code>Instances</code>. This header is needed
+ * to know the number of classes and attributes
+ *
+ * @return the reference to the data stream header
+ */
+ public InstancesHeader getModelContext();
+
}
-
-
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractClassOption.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractClassOption.java
index 1dd66c1..0685855 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractClassOption.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractClassOption.java
@@ -29,208 +29,226 @@
/**
* Abstract class option.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision$
*/
public abstract class AbstractClassOption extends AbstractOption {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- /** The prefix text to use to indicate file. */
- public static final String FILE_PREFIX_STRING = "file:";
+ /** The prefix text to use to indicate file. */
+ public static final String FILE_PREFIX_STRING = "file:";
- /** The prefix text to use to indicate inmem. */
- public static final String INMEM_PREFIX_STRING = "inmem:";
+ /** The prefix text to use to indicate inmem. */
+ public static final String INMEM_PREFIX_STRING = "inmem:";
- /** The current object */
- protected Object currentValue;
+ /** The current object */
+ protected Object currentValue;
- /** The class type */
- protected Class<?> requiredType;
+ /** The class type */
+ protected Class<?> requiredType;
- /** The default command line interface text. */
- protected String defaultCLIString;
+ /** The default command line interface text. */
+ protected String defaultCLIString;
- /** The null text. */
- protected String nullString;
+ /** The null text. */
+ protected String nullString;
- /**
- * Creates a new instance of an abstract option given its class name,
- * command line interface text, its purpose, its class type and its default
- * command line interface text.
- *
- * @param name the name of this option
- * @param cliChar the command line interface text
- * @param purpose the text describing the purpose of this option
- * @param requiredType the class type
- * @param defaultCLIString the default command line interface text
- */
- public AbstractClassOption(String name, char cliChar, String purpose,
- Class<?> requiredType, String defaultCLIString) {
- this(name, cliChar, purpose, requiredType, defaultCLIString, null);
+ /**
+ * Creates a new instance of an abstract option given its class name, command
+ * line interface text, its purpose, its class type and its default command
+ * line interface text.
+ *
+ * @param name
+ * the name of this option
+ * @param cliChar
+ * the command line interface text
+ * @param purpose
+ * the text describing the purpose of this option
+ * @param requiredType
+ * the class type
+ * @param defaultCLIString
+ * the default command line interface text
+ */
+ public AbstractClassOption(String name, char cliChar, String purpose,
+ Class<?> requiredType, String defaultCLIString) {
+ this(name, cliChar, purpose, requiredType, defaultCLIString, null);
+ }
+
+ /**
+ * Creates a new instance of an abstract option given its class name, command
+ * line interface text, its purpose, its class type, default command line
+ * interface text, and its null text.
+ *
+ * @param name
+ * the name of this option
+ * @param cliChar
+ * the command line interface text
+ * @param purpose
+ * the text describing the purpose of this option
+ * @param requiredType
+ * the class type
+ * @param defaultCLIString
+ * the default command line interface text
+ * @param nullString
+ * the null text
+ */
+ public AbstractClassOption(String name, char cliChar, String purpose,
+ Class<?> requiredType, String defaultCLIString, String nullString) {
+ super(name, cliChar, purpose);
+ this.requiredType = requiredType;
+ this.defaultCLIString = defaultCLIString;
+ this.nullString = nullString;
+ resetToDefault();
+ }
+
+ /**
+ * Sets current object.
+ *
+ * @param obj
+ * the object to set as current.
+ */
+ public void setCurrentObject(Object obj) {
+ if (((obj == null) && (this.nullString != null))
+ || this.requiredType.isInstance(obj)
+ || (obj instanceof String)
+ || (obj instanceof File)
+ || ((obj instanceof Task) && this.requiredType.isAssignableFrom(((Task) obj).getTaskResultType()))) {
+ this.currentValue = obj;
+ } else {
+ throw new IllegalArgumentException("Object not of required type.");
}
+ }
- /**
- * Creates a new instance of an abstract option given its class name,
- * command line interface text, its purpose, its class type, default
- * command line interface text, and its null text.
- *
- * @param name the name of this option
- * @param cliChar the command line interface text
- * @param purpose the text describing the purpose of this option
- * @param requiredType the class type
- * @param defaultCLIString the default command line interface text
- * @param nullString the null text
- */
- public AbstractClassOption(String name, char cliChar, String purpose,
- Class<?> requiredType, String defaultCLIString, String nullString) {
- super(name, cliChar, purpose);
- this.requiredType = requiredType;
- this.defaultCLIString = defaultCLIString;
- this.nullString = nullString;
- resetToDefault();
- }
+ /**
+ * Returns the current object.
+ *
+ * @return the current object
+ */
+ public Object getPreMaterializedObject() {
+ return this.currentValue;
+ }
- /**
- * Sets current object.
- *
- * @param obj the object to set as current.
- */
- public void setCurrentObject(Object obj) {
- if (((obj == null) && (this.nullString != null))
- || this.requiredType.isInstance(obj)
- || (obj instanceof String)
- || (obj instanceof File)
- || ((obj instanceof Task) && this.requiredType.isAssignableFrom(((Task) obj).getTaskResultType()))) {
- this.currentValue = obj;
- } else {
- throw new IllegalArgumentException("Object not of required type.");
+ /**
+ * Gets the class type of this option.
+ *
+ * @return the class type of this option
+ */
+ public Class<?> getRequiredType() {
+ return this.requiredType;
+ }
+
+ /**
+ * Gets the null string of this option.
+ *
+ * @return the null string of this option
+ */
+ public String getNullString() {
+ return this.nullString;
+ }
+
+ /**
+ * Gets a materialized object of this option.
+ *
+ * @param monitor
+ * the task monitor to use
+ * @param repository
+ * the object repository to use
+ * @return the materialized object
+ */
+ public Object materializeObject(TaskMonitor monitor,
+ ObjectRepository repository) {
+ if ((this.currentValue == null)
+ || this.requiredType.isInstance(this.currentValue)) {
+ return this.currentValue;
+ } else if (this.currentValue instanceof String) {
+ if (repository != null) {
+ Object inmemObj = repository.getObjectNamed((String) this.currentValue);
+ if (inmemObj == null) {
+ throw new RuntimeException("No object named "
+ + this.currentValue + " found in repository.");
}
+ return inmemObj;
+ }
+ throw new RuntimeException("No object repository available.");
+ } else if (this.currentValue instanceof Task) {
+ Task task = (Task) this.currentValue;
+ Object result = task.doTask(monitor, repository);
+ return result;
+ } else if (this.currentValue instanceof File) {
+ File inputFile = (File) this.currentValue;
+ Object result = null;
+ try {
+ result = SerializeUtils.readFromFile(inputFile);
+ } catch (Exception ex) {
+ throw new RuntimeException("Problem loading "
+ + this.requiredType.getName() + " object from file '"
+ + inputFile.getName() + "':\n" + ex.getMessage(), ex);
+ }
+ return result;
+ } else {
+ throw new RuntimeException(
+ "Could not materialize object of required type "
+ + this.requiredType.getName() + ", found "
+ + this.currentValue.getClass().getName()
+ + " instead.");
}
+ }
- /**
- * Returns the current object.
- *
- * @return the current object
- */
- public Object getPreMaterializedObject() {
- return this.currentValue;
+ @Override
+ public String getDefaultCLIString() {
+ return this.defaultCLIString;
+ }
+
+ /**
+ * Gets the command line interface text of the class.
+ *
+ * @param aClass
+ * the class
+ * @param requiredType
+ * the class type
+ * @return the command line interface text of the class
+ */
+ public static String classToCLIString(Class<?> aClass, Class<?> requiredType) {
+ String className = aClass.getName();
+ String packageName = requiredType.getPackage().getName();
+ if (className.startsWith(packageName)) {
+ // cut off package name
+ className = className.substring(packageName.length() + 1, className.length());
+ } else if (Task.class.isAssignableFrom(aClass)) {
+ packageName = Task.class.getPackage().getName();
+ if (className.startsWith(packageName)) {
+ // cut off task package name
+ className = className.substring(packageName.length() + 1,
+ className.length());
+ }
}
+ return className;
+ }
- /**
- * Gets the class type of this option.
- *
- * @return the class type of this option
- */
- public Class<?> getRequiredType() {
- return this.requiredType;
+ @Override
+ public abstract String getValueAsCLIString();
+
+ @Override
+ public abstract void setValueViaCLIString(String s);
+
+ // @Override
+ // public abstract JComponent getEditComponent();
+
+ /**
+ * Gets the class name without its package name prefix.
+ *
+ * @param className
+ * the name of the class
+ * @param expectedType
+ * the type of the class
+ * @return the class name without its package name prefix
+ */
+ public static String stripPackagePrefix(String className, Class<?> expectedType) {
+ if (className.startsWith(expectedType.getPackage().getName())) {
+ return className.substring(expectedType.getPackage().getName().length() + 1);
}
-
- /**
- * Gets the null string of this option.
- *
- * @return the null string of this option
- */
- public String getNullString() {
- return this.nullString;
- }
-
- /**
- * Gets a materialized object of this option.
- *
- * @param monitor the task monitor to use
- * @param repository the object repository to use
- * @return the materialized object
- */
- public Object materializeObject(TaskMonitor monitor,
- ObjectRepository repository) {
- if ((this.currentValue == null)
- || this.requiredType.isInstance(this.currentValue)) {
- return this.currentValue;
- } else if (this.currentValue instanceof String) {
- if (repository != null) {
- Object inmemObj = repository.getObjectNamed((String) this.currentValue);
- if (inmemObj == null) {
- throw new RuntimeException("No object named "
- + this.currentValue + " found in repository.");
- }
- return inmemObj;
- }
- throw new RuntimeException("No object repository available.");
- } else if (this.currentValue instanceof Task) {
- Task task = (Task) this.currentValue;
- Object result = task.doTask(monitor, repository);
- return result;
- } else if (this.currentValue instanceof File) {
- File inputFile = (File) this.currentValue;
- Object result = null;
- try {
- result = SerializeUtils.readFromFile(inputFile);
- } catch (Exception ex) {
- throw new RuntimeException("Problem loading "
- + this.requiredType.getName() + " object from file '"
- + inputFile.getName() + "':\n" + ex.getMessage(), ex);
- }
- return result;
- } else {
- throw new RuntimeException(
- "Could not materialize object of required type "
- + this.requiredType.getName() + ", found "
- + this.currentValue.getClass().getName()
- + " instead.");
- }
- }
-
- @Override
- public String getDefaultCLIString() {
- return this.defaultCLIString;
- }
-
- /**
- * Gets the command line interface text of the class.
- *
- * @param aClass the class
- * @param requiredType the class type
- * @return the command line interface text of the class
- */
- public static String classToCLIString(Class<?> aClass, Class<?> requiredType) {
- String className = aClass.getName();
- String packageName = requiredType.getPackage().getName();
- if (className.startsWith(packageName)) {
- // cut off package name
- className = className.substring(packageName.length() + 1, className.length());
- } else if (Task.class.isAssignableFrom(aClass)) {
- packageName = Task.class.getPackage().getName();
- if (className.startsWith(packageName)) {
- // cut off task package name
- className = className.substring(packageName.length() + 1,
- className.length());
- }
- }
- return className;
- }
-
- @Override
- public abstract String getValueAsCLIString();
-
- @Override
- public abstract void setValueViaCLIString(String s);
-
- //@Override
- //public abstract JComponent getEditComponent();
-
- /**
- * Gets the class name without its package name prefix.
- *
- * @param className the name of the class
- * @param expectedType the type of the class
- * @return the class name without its package name prefix
- */
- public static String stripPackagePrefix(String className, Class<?> expectedType) {
- if (className.startsWith(expectedType.getPackage().getName())) {
- return className.substring(expectedType.getPackage().getName().length() + 1);
- }
- return className;
- }
+ return className;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractOptionHandler.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractOptionHandler.java
index 22cc1f0..546e678 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractOptionHandler.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/AbstractOptionHandler.java
@@ -27,167 +27,142 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Abstract Option Handler. All classes that have options in
- * MOA extend this class.
- *
+ * Abstract Option Handler. All classes that have options in MOA extend this
+ * class.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public abstract class AbstractOptionHandler extends AbstractMOAObject implements
- OptionHandler {
+ OptionHandler {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- /** Options to handle */
- //protected Options options;
+ /** Options to handle */
+ // protected Options options;
- /** Dictionary with option texts and objects */
- //protected Map<String, Object> classOptionNamesToPreparedObjects;
+ /** Dictionary with option texts and objects */
+ // protected Map<String, Object> classOptionNamesToPreparedObjects;
- @Override
- public String getPurposeString() {
- return "Anonymous object: purpose undocumented.";
- }
+ @Override
+ public String getPurposeString() {
+ return "Anonymous object: purpose undocumented.";
+ }
- @Override
- public Options getOptions() {
- /*if (this.options == null) {
- this.options = new Options();
- if (this.config== null) {
- this.config = new OptionsHandler(this, "");
- this.config.prepareForUse();
- }
- Option[] myOptions = this.config.discoverOptionsViaReflection();
- for (Option option : myOptions) {
- this.options.addOption(option);
- }
- }
- return this.options;*/
- if ( this.config == null){
- this.config = new OptionsHandler(this, "");
- //this.config.prepareForUse(monitor, repository);
- }
- return this.config.getOptions();
- }
-
- @Override
- public void prepareForUse() {
- prepareForUse(new NullMonitor(), null);
- }
-
- protected OptionsHandler config;
-
- @Override
- public void prepareForUse(TaskMonitor monitor, ObjectRepository repository) {
- //prepareClassOptions(monitor, repository);
- if ( this.config == null){
- this.config = new OptionsHandler(this, "");
- this.config.prepareForUse(monitor, repository);
- }
- prepareForUseImpl(monitor, repository);
- }
-
- /**
- * This method describes the implementation of how to prepare this object for use.
- * All classes that extends this class have to implement <code>prepareForUseImpl</code>
- * and not <code>prepareForUse</code> since
- * <code>prepareForUse</code> calls <code>prepareForUseImpl</code>.
- *
- * @param monitor the TaskMonitor to use
- * @param repository the ObjectRepository to use
+ @Override
+ public Options getOptions() {
+ /*
+ * if (this.options == null) { this.options = new Options(); if
+ * (this.config== null) { this.config = new OptionsHandler(this, "");
+ * this.config.prepareForUse(); } Option[] myOptions =
+ * this.config.discoverOptionsViaReflection(); for (Option option :
+ * myOptions) { this.options.addOption(option); } } return this.options;
*/
- protected abstract void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository);
-
- @Override
- public String getCLICreationString(Class<?> expectedType) {
- return ClassOption.stripPackagePrefix(this.getClass().getName(),
- expectedType)
- + " " + getOptions().getAsCLIString();
+ if (this.config == null) {
+ this.config = new OptionsHandler(this, "");
+ // this.config.prepareForUse(monitor, repository);
}
+ return this.config.getOptions();
+ }
- @Override
- public OptionHandler copy() {
- return (OptionHandler) super.copy();
+ @Override
+ public void prepareForUse() {
+ prepareForUse(new NullMonitor(), null);
+ }
+
+ protected OptionsHandler config;
+
+ @Override
+ public void prepareForUse(TaskMonitor monitor, ObjectRepository repository) {
+ // prepareClassOptions(monitor, repository);
+ if (this.config == null) {
+ this.config = new OptionsHandler(this, "");
+ this.config.prepareForUse(monitor, repository);
}
+ prepareForUseImpl(monitor, repository);
+ }
- /**
- * Gets the options of this class via reflection.
- *
- * @return an array of options
- */
- /* protected Option[] discoverOptionsViaReflection() {
- Class<? extends AbstractOptionHandler> c = this.getClass();
- Field[] fields = c.getFields();
- List<Option> optList = new LinkedList<Option>();
- for (Field field : fields) {
- String fName = field.getName();
- Class<?> fType = field.getType();
- if (fType.getName().endsWith("Option")) {
- if (Option.class.isAssignableFrom(fType)) {
- Option oVal = null;
- try {
- field.setAccessible(true);
- oVal = (Option) field.get(this);
- } catch (IllegalAccessException ignored) {
- // cannot access this field
- }
- if (oVal != null) {
- optList.add(oVal);
- }
- }
- }
- }
- return optList.toArray(new Option[optList.size()]);
- }*/
+ /**
+ * This method describes the implementation of how to prepare this object for
+ * use. All classes that extends this class have to implement
+ * <code>prepareForUseImpl</code> and not <code>prepareForUse</code> since
+ * <code>prepareForUse</code> calls <code>prepareForUseImpl</code>.
+ *
+ * @param monitor
+ * the TaskMonitor to use
+ * @param repository
+ * the ObjectRepository to use
+ */
+ protected abstract void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository);
- /**
- * Prepares the options of this class.
- *
- * @param monitor the TaskMonitor to use
- * @param repository the ObjectRepository to use
- */
- protected void prepareClassOptions(TaskMonitor monitor,
- ObjectRepository repository) {
- this.config.prepareClassOptions(monitor, repository);
- }/*
- this.classOptionNamesToPreparedObjects = null;
- Option[] optionArray = getOptions().getOptionArray();
- for (Option option : optionArray) {
- if (option instanceof ClassOption) {
- ClassOption classOption = (ClassOption) option;
- monitor.setCurrentActivity("Materializing option "
- + classOption.getName() + "...", -1.0);
- Object optionObj = classOption.materializeObject(monitor,
- repository);
- if (monitor.taskShouldAbort()) {
- return;
- }
- if (optionObj instanceof OptionHandler) {
- monitor.setCurrentActivity("Preparing option "
- + classOption.getName() + "...", -1.0);
- ((OptionHandler) optionObj).prepareForUse(monitor,
- repository);
- if (monitor.taskShouldAbort()) {
- return;
- }
- }
- if (this.classOptionNamesToPreparedObjects == null) {
- this.classOptionNamesToPreparedObjects = new HashMap<String, Object>();
- }
- this.classOptionNamesToPreparedObjects.put(option.getName(),
- optionObj);
- }
- }
- }*/
+ @Override
+ public String getCLICreationString(Class<?> expectedType) {
+ return ClassOption.stripPackagePrefix(this.getClass().getName(),
+ expectedType)
+ + " " + getOptions().getAsCLIString();
+ }
- /**
- * Gets a prepared option of this class.
- *
- * @param opt the class option to get
- * @return an option stored in the dictionary
- */
- protected Object getPreparedClassOption(ClassOption opt) {
- return this.config.getPreparedClassOption(opt);
- }
+ @Override
+ public OptionHandler copy() {
+ return (OptionHandler) super.copy();
+ }
+
+ /**
+ * Gets the options of this class via reflection.
+ *
+ * @return an array of options
+ */
+ /*
+ * protected Option[] discoverOptionsViaReflection() { Class<? extends
+ * AbstractOptionHandler> c = this.getClass(); Field[] fields = c.getFields();
+ * List<Option> optList = new LinkedList<Option>(); for (Field field : fields)
+ * { String fName = field.getName(); Class<?> fType = field.getType(); if
+ * (fType.getName().endsWith("Option")) { if
+ * (Option.class.isAssignableFrom(fType)) { Option oVal = null; try {
+ * field.setAccessible(true); oVal = (Option) field.get(this); } catch
+ * (IllegalAccessException ignored) { // cannot access this field } if (oVal
+ * != null) { optList.add(oVal); } } } } return optList.toArray(new
+ * Option[optList.size()]); }
+ */
+
+ /**
+ * Prepares the options of this class.
+ *
+ * @param monitor
+ * the TaskMonitor to use
+ * @param repository
+ * the ObjectRepository to use
+ */
+ protected void prepareClassOptions(TaskMonitor monitor,
+ ObjectRepository repository) {
+ this.config.prepareClassOptions(monitor, repository);
+ }/*
+ * this.classOptionNamesToPreparedObjects = null; Option[] optionArray =
+ * getOptions().getOptionArray(); for (Option option : optionArray) { if
+ * (option instanceof ClassOption) { ClassOption classOption = (ClassOption)
+ * option; monitor.setCurrentActivity("Materializing option " +
+ * classOption.getName() + "...", -1.0); Object optionObj =
+ * classOption.materializeObject(monitor, repository); if
+ * (monitor.taskShouldAbort()) { return; } if (optionObj instanceof
+ * OptionHandler) { monitor.setCurrentActivity("Preparing option " +
+ * classOption.getName() + "...", -1.0); ((OptionHandler)
+ * optionObj).prepareForUse(monitor, repository); if
+ * (monitor.taskShouldAbort()) { return; } } if
+ * (this.classOptionNamesToPreparedObjects == null) {
+ * this.classOptionNamesToPreparedObjects = new HashMap<String, Object>(); }
+ * this.classOptionNamesToPreparedObjects.put(option.getName(), optionObj); }
+ * } }
+ */
+
+ /**
+ * Gets a prepared option of this class.
+ *
+ * @param opt
+ * the class option to get
+ * @return an option stored in the dictionary
+ */
+ protected Object getPreparedClassOption(ClassOption opt) {
+ return this.config.getPreparedClassOption(opt);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/ClassOption.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/ClassOption.java
index 64eff95..3e8e62c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/ClassOption.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/ClassOption.java
@@ -28,148 +28,149 @@
/**
* Class option.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class ClassOption extends AbstractClassOption {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public ClassOption(String name, char cliChar, String purpose,
- Class<?> requiredType, String defaultCLIString) {
- super(name, cliChar, purpose, requiredType, defaultCLIString);
+ public ClassOption(String name, char cliChar, String purpose,
+ Class<?> requiredType, String defaultCLIString) {
+ super(name, cliChar, purpose, requiredType, defaultCLIString);
+ }
+
+ public ClassOption(String name, char cliChar, String purpose,
+ Class<?> requiredType, String defaultCLIString, String nullString) {
+ super(name, cliChar, purpose, requiredType, defaultCLIString, nullString);
+ }
+
+ @Override
+ public String getValueAsCLIString() {
+ if ((this.currentValue == null) && (this.nullString != null)) {
+ return this.nullString;
}
+ return objectToCLIString(this.currentValue, this.requiredType);
+ }
- public ClassOption(String name, char cliChar, String purpose,
- Class<?> requiredType, String defaultCLIString, String nullString) {
- super(name, cliChar, purpose, requiredType, defaultCLIString, nullString);
+ @Override
+ public void setValueViaCLIString(String s) {
+ if ((this.nullString != null)
+ && ((s == null) || (s.length() == 0) || s.equals(this.nullString))) {
+ this.currentValue = null;
+ } else {
+ try {
+ this.currentValue = cliStringToObject(s, this.requiredType,
+ null);
+ } catch (Exception e) {
+ throw new IllegalArgumentException("Problems with option: " + getName(), e);
+ }
}
+ }
- @Override
- public String getValueAsCLIString() {
- if ((this.currentValue == null) && (this.nullString != null)) {
- return this.nullString;
- }
- return objectToCLIString(this.currentValue, this.requiredType);
+ public static String objectToCLIString(Object obj, Class<?> requiredType) {
+ if (obj == null) {
+ return "";
}
-
- @Override
- public void setValueViaCLIString(String s) {
- if ((this.nullString != null)
- && ((s == null) || (s.length() == 0) || s.equals(this.nullString))) {
- this.currentValue = null;
- } else {
- try {
- this.currentValue = cliStringToObject(s, this.requiredType,
- null);
- } catch (Exception e) {
- throw new IllegalArgumentException("Problems with option: " + getName(), e);
- }
- }
+ if (obj instanceof File) {
+ return (FILE_PREFIX_STRING + ((File) obj).getPath());
}
-
- public static String objectToCLIString(Object obj, Class<?> requiredType) {
- if (obj == null) {
- return "";
- }
- if (obj instanceof File) {
- return (FILE_PREFIX_STRING + ((File) obj).getPath());
- }
- if (obj instanceof String) {
- return (INMEM_PREFIX_STRING + obj);
- }
- String className = classToCLIString(obj.getClass(), requiredType);
- if (obj instanceof OptionHandler) {
- String subOptions = ((OptionHandler) obj).getOptions().getAsCLIString();
- if (subOptions.length() > 0) {
- return (className + " " + subOptions);
- }
- }
- return className;
+ if (obj instanceof String) {
+ return (INMEM_PREFIX_STRING + obj);
}
+ String className = classToCLIString(obj.getClass(), requiredType);
+ if (obj instanceof OptionHandler) {
+ String subOptions = ((OptionHandler) obj).getOptions().getAsCLIString();
+ if (subOptions.length() > 0) {
+ return (className + " " + subOptions);
+ }
+ }
+ return className;
+ }
- public static Object cliStringToObject(String cliString,
- Class<?> requiredType, Option[] externalOptions) throws Exception {
- if (cliString.startsWith(FILE_PREFIX_STRING)) {
- return new File(cliString.substring(FILE_PREFIX_STRING.length()));
- }
- if (cliString.startsWith(INMEM_PREFIX_STRING)) {
- return cliString.substring(INMEM_PREFIX_STRING.length());
- }
- cliString = cliString.trim();
- int firstSpaceIndex = cliString.indexOf(' ', 0);
- String className;
- String classOptions;
- if (firstSpaceIndex > 0) {
- className = cliString.substring(0, firstSpaceIndex);
- classOptions = cliString.substring(firstSpaceIndex + 1, cliString.length());
- classOptions = classOptions.trim();
- } else {
- className = cliString;
- classOptions = "";
- }
- Class<?> classObject;
+ public static Object cliStringToObject(String cliString,
+ Class<?> requiredType, Option[] externalOptions) throws Exception {
+ if (cliString.startsWith(FILE_PREFIX_STRING)) {
+ return new File(cliString.substring(FILE_PREFIX_STRING.length()));
+ }
+ if (cliString.startsWith(INMEM_PREFIX_STRING)) {
+ return cliString.substring(INMEM_PREFIX_STRING.length());
+ }
+ cliString = cliString.trim();
+ int firstSpaceIndex = cliString.indexOf(' ', 0);
+ String className;
+ String classOptions;
+ if (firstSpaceIndex > 0) {
+ className = cliString.substring(0, firstSpaceIndex);
+ classOptions = cliString.substring(firstSpaceIndex + 1, cliString.length());
+ classOptions = classOptions.trim();
+ } else {
+ className = cliString;
+ classOptions = "";
+ }
+ Class<?> classObject;
+ try {
+ classObject = Class.forName(className);
+ } catch (Throwable t1) {
+ try {
+ // try prepending default package
+ classObject = Class.forName(requiredType.getPackage().getName()
+ + "." + className);
+ } catch (Throwable t2) {
try {
- classObject = Class.forName(className);
- } catch (Throwable t1) {
- try {
- // try prepending default package
- classObject = Class.forName(requiredType.getPackage().getName()
- + "." + className);
- } catch (Throwable t2) {
- try {
- // try prepending task package
- classObject = Class.forName(Task.class.getPackage().getName()
- + "." + className);
- } catch (Throwable t3) {
- throw new Exception("Class not found: " + className);
- }
- }
+ // try prepending task package
+ classObject = Class.forName(Task.class.getPackage().getName()
+ + "." + className);
+ } catch (Throwable t3) {
+ throw new Exception("Class not found: " + className);
}
- Object classInstance;
- try {
- classInstance = classObject.newInstance();
- } catch (Exception ex) {
- throw new Exception("Problem creating instance of class: "
- + className, ex);
- }
- if (requiredType.isInstance(classInstance)
- || ((classInstance instanceof Task) && requiredType.isAssignableFrom(((Task) classInstance).getTaskResultType()))) {
- Options options = new Options();
- if (externalOptions != null) {
- for (Option option : externalOptions) {
- options.addOption(option);
- }
- }
- if (classInstance instanceof OptionHandler) {
- Option[] objectOptions = ((OptionHandler) classInstance).getOptions().getOptionArray();
- for (Option option : objectOptions) {
- options.addOption(option);
- }
- }
- try {
- options.setViaCLIString(classOptions);
- } catch (Exception ex) {
- throw new Exception("Problem with options to '"
- + className
- + "'."
- + "\n\nValid options for "
- + className
- + ":\n"
- + ((OptionHandler) classInstance).getOptions().getHelpString(), ex);
- } finally {
- options.removeAllOptions(); // clean up listener refs
- }
- } else {
- throw new Exception("Class named '" + className
- + "' is not an instance of " + requiredType.getName() + ".");
- }
- return classInstance;
+ }
}
+ Object classInstance;
+ try {
+ classInstance = classObject.newInstance();
+ } catch (Exception ex) {
+ throw new Exception("Problem creating instance of class: "
+ + className, ex);
+ }
+ if (requiredType.isInstance(classInstance)
+ || ((classInstance instanceof Task) && requiredType
+ .isAssignableFrom(((Task) classInstance).getTaskResultType()))) {
+ Options options = new Options();
+ if (externalOptions != null) {
+ for (Option option : externalOptions) {
+ options.addOption(option);
+ }
+ }
+ if (classInstance instanceof OptionHandler) {
+ Option[] objectOptions = ((OptionHandler) classInstance).getOptions().getOptionArray();
+ for (Option option : objectOptions) {
+ options.addOption(option);
+ }
+ }
+ try {
+ options.setViaCLIString(classOptions);
+ } catch (Exception ex) {
+ throw new Exception("Problem with options to '"
+ + className
+ + "'."
+ + "\n\nValid options for "
+ + className
+ + ":\n"
+ + ((OptionHandler) classInstance).getOptions().getHelpString(), ex);
+ } finally {
+ options.removeAllOptions(); // clean up listener refs
+ }
+ } else {
+ throw new Exception("Class named '" + className
+ + "' is not an instance of " + requiredType.getName() + ".");
+ }
+ return classInstance;
+ }
- //@Override
- //public JComponent getEditComponent() {
- // return new ClassOptionEditComponent(this);
- //}
+ // @Override
+ // public JComponent getEditComponent() {
+ // return new ClassOptionEditComponent(this);
+ // }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionHandler.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionHandler.java
index b88cada..699303b 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionHandler.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionHandler.java
@@ -27,53 +27,55 @@
import com.yahoo.labs.samoa.moa.tasks.TaskMonitor;
/**
- * Interface representing an object that handles options or parameters.
- *
+ * Interface representing an object that handles options or parameters.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface OptionHandler extends MOAObject, Configurable {
- /**
- * Gets the purpose of this object
- *
- * @return the string with the purpose of this object
- */
- public String getPurposeString();
+ /**
+ * Gets the purpose of this object
+ *
+ * @return the string with the purpose of this object
+ */
+ public String getPurposeString();
- /**
- * Gets the options of this object
- *
- * @return the options of this object
- */
- public Options getOptions();
+ /**
+ * Gets the options of this object
+ *
+ * @return the options of this object
+ */
+ public Options getOptions();
- /**
- * This method prepares this object for use.
- *
- */
- public void prepareForUse();
+ /**
+ * This method prepares this object for use.
+ *
+ */
+ public void prepareForUse();
- /**
- * This method prepares this object for use.
- *
- * @param monitor the TaskMonitor to use
- * @param repository the ObjectRepository to use
- */
- public void prepareForUse(TaskMonitor monitor, ObjectRepository repository);
+ /**
+ * This method prepares this object for use.
+ *
+ * @param monitor
+ * the TaskMonitor to use
+ * @param repository
+ * the ObjectRepository to use
+ */
+ public void prepareForUse(TaskMonitor monitor, ObjectRepository repository);
- /**
- * This method produces a copy of this object.
- *
- * @return a copy of this object
- */
- @Override
- public OptionHandler copy();
+ /**
+ * This method produces a copy of this object.
+ *
+ * @return a copy of this object
+ */
+ @Override
+ public OptionHandler copy();
- /**
- * Gets the Command Line Interface text to create the object
- *
- * @return the Command Line Interface text to create the object
- */
- public String getCLICreationString(Class<?> expectedType);
+ /**
+ * Gets the Command Line Interface text to create the object
+ *
+ * @return the Command Line Interface text to create the object
+ */
+ public String getCLICreationString(Class<?> expectedType);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionsHandler.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionsHandler.java
index c643115..42628b2 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionsHandler.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/options/OptionsHandler.java
@@ -20,7 +20,6 @@
* #L%
*/
-
import java.util.HashMap;
import com.github.javacliparser.JavaCLIParser;
import com.github.javacliparser.Option;
@@ -34,169 +33,151 @@
*/
/**
- *
+ *
* @author abifet
*/
public class OptionsHandler extends JavaCLIParser {
- //public Object handler;
-
- public OptionsHandler(Object c, String cliString) {
- super(c,cliString);
- //this.handler = c;
- //this.prepareForUse();
- /*int firstSpaceIndex = cliString.indexOf(' ', 0);
- String classOptions;
- String className;
- if (firstSpaceIndex > 0) {
- className = cliString.substring(0, firstSpaceIndex);
- classOptions = cliString.substring(firstSpaceIndex + 1, cliString.length());
- classOptions = classOptions.trim();
- } else {
- className = cliString;
- classOptions = "";
- }*/
- //options.setViaCLIString(cliString);
- }
-
-
- //private static final long serialVersionUID = 1L;
+ // public Object handler;
- /** Options to handle */
- //protected Options options;
+ public OptionsHandler(Object c, String cliString) {
+ super(c, cliString);
+ // this.handler = c;
+ // this.prepareForUse();
+ /*
+ * int firstSpaceIndex = cliString.indexOf(' ', 0); String classOptions;
+ * String className; if (firstSpaceIndex > 0) { className =
+ * cliString.substring(0, firstSpaceIndex); classOptions =
+ * cliString.substring(firstSpaceIndex + 1, cliString.length());
+ * classOptions = classOptions.trim(); } else { className = cliString;
+ * classOptions = ""; }
+ */
+ // options.setViaCLIString(cliString);
+ }
- /** Dictionary with option texts and objects */
- //protected Map<String, Object> classOptionNamesToPreparedObjects;
+ // private static final long serialVersionUID = 1L;
+ /** Options to handle */
+ // protected Options options;
- /*public String getPurposeString() {
- return "Anonymous object: purpose undocumented.";
- }
+ /** Dictionary with option texts and objects */
+ // protected Map<String, Object> classOptionNamesToPreparedObjects;
- public Options getOptions() {
- if (this.options == null) {
- this.options = new Options();
- Option[] myOptions = discoverOptionsViaReflection();
- for (Option option : myOptions) {
- this.options.addOption(option);
- }
+ /*
+ * public String getPurposeString() { return
+ * "Anonymous object: purpose undocumented."; }
+ *
+ * public Options getOptions() { if (this.options == null) { this.options =
+ * new Options(); Option[] myOptions = discoverOptionsViaReflection(); for
+ * (Option option : myOptions) { this.options.addOption(option); } } return
+ * this.options; }
+ */
+
+ public void prepareForUse() {
+ prepareForUse(new NullMonitor(), null);
+ }
+
+ public void prepareForUse(TaskMonitor monitor, ObjectRepository repository) {
+ prepareClassOptions(monitor, repository);
+ // prepareForUseImpl(monitor, repository);
+ }
+
+ /**
+ * This method describes the implementation of how to prepare this object for
+ * use. All classes that extends this class have to implement
+ * <code>prepareForUseImpl</code> and not <code>prepareForUse</code> since
+ * <code>prepareForUse</code> calls <code>prepareForUseImpl</code>.
+ *
+ * @param monitor
+ * the TaskMonitor to use
+ * @param repository
+ * the ObjectRepository to use
+ */
+ // protected abstract void prepareForUseImpl(TaskMonitor monitor,
+ // ObjectRepository repository);
+
+ /*
+ * public String getCLICreationString(Class<?> expectedType) { return
+ * ClassOption.stripPackagePrefix(this.getClass().getName(), expectedType) +
+ * " " + getOptions().getAsCLIString(); }
+ */
+
+ /**
+ * Gets the options of this class via reflection.
+ *
+ * @return an array of options
+ */
+ /*
+ * public Option[] discoverOptionsViaReflection() { //Class<? extends
+ * AbstractOptionHandler> c = this.getClass(); Class c =
+ * this.handler.getClass(); Field[] fields = c.getFields(); List<Option>
+ * optList = new LinkedList<Option>(); for (Field field : fields) { String
+ * fName = field.getName(); Class<?> fType = field.getType(); if
+ * (fType.getName().endsWith("Option")) { if
+ * (Option.class.isAssignableFrom(fType)) { Option oVal = null; try {
+ * field.setAccessible(true); oVal = (Option) field.get(this.handler); } catch
+ * (IllegalAccessException ignored) { // cannot access this field } if (oVal
+ * != null) { optList.add(oVal); } } } } return optList.toArray(new
+ * Option[optList.size()]); }
+ */
+
+ /**
+ * Prepares the options of this class.
+ *
+ * @param monitor
+ * the TaskMonitor to use
+ * @param repository
+ * the ObjectRepository to use
+ */
+ public void prepareClassOptions(TaskMonitor monitor,
+ ObjectRepository repository) {
+ this.classOptionNamesToPreparedObjects = null;
+ Option[] optionArray = getOptions().getOptionArray();
+ for (Option option : optionArray) {
+ if (option instanceof ClassOption) {
+ ClassOption classOption = (ClassOption) option;
+ monitor.setCurrentActivity("Materializing option "
+ + classOption.getName() + "...", -1.0);
+ Object optionObj = classOption.materializeObject(monitor,
+ repository);
+ if (monitor.taskShouldAbort()) {
+ return;
}
- return this.options;
- }*/
-
- public void prepareForUse() {
- prepareForUse(new NullMonitor(), null);
- }
-
- public void prepareForUse(TaskMonitor monitor, ObjectRepository repository) {
- prepareClassOptions(monitor, repository);
- //prepareForUseImpl(monitor, repository);
- }
-
- /**
- * This method describes the implementation of how to prepare this object for use.
- * All classes that extends this class have to implement <code>prepareForUseImpl</code>
- * and not <code>prepareForUse</code> since
- * <code>prepareForUse</code> calls <code>prepareForUseImpl</code>.
- *
- * @param monitor the TaskMonitor to use
- * @param repository the ObjectRepository to use
- */
- //protected abstract void prepareForUseImpl(TaskMonitor monitor,
- // ObjectRepository repository);
-
- /* public String getCLICreationString(Class<?> expectedType) {
- return ClassOption.stripPackagePrefix(this.getClass().getName(),
- expectedType)
- + " " + getOptions().getAsCLIString();
- }*/
-
-
- /**
- * Gets the options of this class via reflection.
- *
- * @return an array of options
- */
- /*public Option[] discoverOptionsViaReflection() {
- //Class<? extends AbstractOptionHandler> c = this.getClass();
- Class c = this.handler.getClass();
- Field[] fields = c.getFields();
- List<Option> optList = new LinkedList<Option>();
- for (Field field : fields) {
- String fName = field.getName();
- Class<?> fType = field.getType();
- if (fType.getName().endsWith("Option")) {
- if (Option.class.isAssignableFrom(fType)) {
- Option oVal = null;
- try {
- field.setAccessible(true);
- oVal = (Option) field.get(this.handler);
- } catch (IllegalAccessException ignored) {
- // cannot access this field
- }
- if (oVal != null) {
- optList.add(oVal);
- }
- }
- }
+ if (optionObj instanceof OptionHandler) {
+ monitor.setCurrentActivity("Preparing option "
+ + classOption.getName() + "...", -1.0);
+ ((OptionHandler) optionObj).prepareForUse(monitor,
+ repository);
+ if (monitor.taskShouldAbort()) {
+ return;
+ }
}
- return optList.toArray(new Option[optList.size()]);
- }*/
-
- /**
- * Prepares the options of this class.
- *
- * @param monitor the TaskMonitor to use
- * @param repository the ObjectRepository to use
- */
- public void prepareClassOptions(TaskMonitor monitor,
- ObjectRepository repository) {
- this.classOptionNamesToPreparedObjects = null;
- Option[] optionArray = getOptions().getOptionArray();
- for (Option option : optionArray) {
- if (option instanceof ClassOption) {
- ClassOption classOption = (ClassOption) option;
- monitor.setCurrentActivity("Materializing option "
- + classOption.getName() + "...", -1.0);
- Object optionObj = classOption.materializeObject(monitor,
- repository);
- if (monitor.taskShouldAbort()) {
- return;
- }
- if (optionObj instanceof OptionHandler) {
- monitor.setCurrentActivity("Preparing option "
- + classOption.getName() + "...", -1.0);
- ((OptionHandler) optionObj).prepareForUse(monitor,
- repository);
- if (monitor.taskShouldAbort()) {
- return;
- }
- }
- if (this.classOptionNamesToPreparedObjects == null) {
- this.classOptionNamesToPreparedObjects = new HashMap<String, Object>();
- }
- this.classOptionNamesToPreparedObjects.put(option.getName(),
- optionObj);
- }
- }
- }
-
- /**
- * Gets a prepared option of this class.
- *
- * @param opt the class option to get
- * @return an option stored in the dictionary
- */
- public Object getPreparedClassOption(ClassOption opt) {
if (this.classOptionNamesToPreparedObjects == null) {
- this.prepareForUse();
- }
- return this.classOptionNamesToPreparedObjects.get(opt.getName());
+ this.classOptionNamesToPreparedObjects = new HashMap<String, Object>();
+ }
+ this.classOptionNamesToPreparedObjects.put(option.getName(),
+ optionObj);
+ }
}
+ }
- //@Override
- //public void getDescription(StringBuilder sb, int i) {
- // throw new UnsupportedOperationException("Not supported yet.");
- //}
-
+ /**
+ * Gets a prepared option of this class.
+ *
+ * @param opt
+ * the class option to get
+ * @return an option stored in the dictionary
+ */
+ public Object getPreparedClassOption(ClassOption opt) {
+ if (this.classOptionNamesToPreparedObjects == null) {
+ this.prepareForUse();
+ }
+ return this.classOptionNamesToPreparedObjects.get(opt.getName());
+ }
+
+ // @Override
+ // public void getDescription(StringBuilder sb, int i) {
+ // throw new UnsupportedOperationException("Not supported yet.");
+ // }
+
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ArffFileStream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ArffFileStream.java
index aa2717d..0edd48f 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ArffFileStream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ArffFileStream.java
@@ -39,161 +39,161 @@
/**
* Stream reader of ARFF files.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class ArffFileStream extends AbstractOptionHandler implements InstanceStream {
- @Override
- public String getPurposeString() {
- return "A stream read from an ARFF file.";
+ @Override
+ public String getPurposeString() {
+ return "A stream read from an ARFF file.";
+ }
+
+ private static final long serialVersionUID = 1L;
+
+ public FileOption arffFileOption = new FileOption("arffFile", 'f',
+ "ARFF file to load.", null, "arff", false);
+
+ public IntOption classIndexOption = new IntOption(
+ "classIndex",
+ 'c',
+ "Class index of data. 0 for none or -1 for last attribute in file.",
+ -1, -1, Integer.MAX_VALUE);
+
+ protected Instances instances;
+
+ transient protected Reader fileReader;
+
+ protected boolean hitEndOfFile;
+
+ protected InstanceExample lastInstanceRead;
+
+ protected int numInstancesRead;
+
+ transient protected InputStreamProgressMonitor fileProgressMonitor;
+
+ protected boolean hasStarted;
+
+ public ArffFileStream() {
+ }
+
+ public ArffFileStream(String arffFileName, int classIndex) {
+ this.arffFileOption.setValue(arffFileName);
+ this.classIndexOption.setValue(classIndex);
+ this.hasStarted = false;
+ restart();
+ }
+
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ // restart();
+ this.hasStarted = false;
+ this.lastInstanceRead = null;
+ }
+
+ @Override
+ public InstancesHeader getHeader() {
+ return new InstancesHeader(this.instances);
+ }
+
+ @Override
+ public long estimatedRemainingInstances() {
+ double progressFraction = this.fileProgressMonitor.getProgressFraction();
+ if ((progressFraction > 0.0) && (this.numInstancesRead > 0)) {
+ return (long) ((this.numInstancesRead / progressFraction) - this.numInstancesRead);
}
+ return -1;
+ }
- private static final long serialVersionUID = 1L;
+ @Override
+ public boolean hasMoreInstances() {
+ return !this.hitEndOfFile;
+ }
- public FileOption arffFileOption = new FileOption("arffFile", 'f',
- "ARFF file to load.", null, "arff", false);
-
- public IntOption classIndexOption = new IntOption(
- "classIndex",
- 'c',
- "Class index of data. 0 for none or -1 for last attribute in file.",
- -1, -1, Integer.MAX_VALUE);
-
- protected Instances instances;
-
- transient protected Reader fileReader;
-
- protected boolean hitEndOfFile;
-
- protected InstanceExample lastInstanceRead;
-
- protected int numInstancesRead;
-
- transient protected InputStreamProgressMonitor fileProgressMonitor;
-
- protected boolean hasStarted;
-
- public ArffFileStream() {
+ @Override
+ public InstanceExample nextInstance() {
+ if (this.lastInstanceRead == null) {
+ readNextInstanceFromFile();
}
+ InstanceExample prevInstance = this.lastInstanceRead;
+ this.hitEndOfFile = !readNextInstanceFromFile();
+ return prevInstance;
+ }
- public ArffFileStream(String arffFileName, int classIndex) {
- this.arffFileOption.setValue(arffFileName);
- this.classIndexOption.setValue(classIndex);
- this.hasStarted = false;
- restart();
+ @Override
+ public boolean isRestartable() {
+ return true;
+ }
+
+ @Override
+ public void restart() {
+ try {
+ reset();
+ // this.hitEndOfFile = !readNextInstanceFromFile();
+ } catch (IOException ioe) {
+ throw new RuntimeException("ArffFileStream restart failed.", ioe);
}
+ }
- @Override
- public void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- //restart();
- this.hasStarted = false;
- this.lastInstanceRead = null;
+ protected boolean readNextInstanceFromFile() {
+ boolean ret;
+ if (!this.hasStarted) {
+ try {
+ reset();
+ ret = getNextInstanceFromFile();
+ this.hitEndOfFile = !ret;
+ } catch (IOException ioe) {
+ throw new RuntimeException("ArffFileStream restart failed.", ioe);
+ }
+ this.hasStarted = true;
+ } else {
+ ret = getNextInstanceFromFile();
}
+ return ret;
+ }
- @Override
- public InstancesHeader getHeader() {
- return new InstancesHeader(this.instances);
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
+
+ private void reset() throws IOException {
+ if (this.fileReader != null) {
+ this.fileReader.close();
}
-
- @Override
- public long estimatedRemainingInstances() {
- double progressFraction = this.fileProgressMonitor.getProgressFraction();
- if ((progressFraction > 0.0) && (this.numInstancesRead > 0)) {
- return (long) ((this.numInstancesRead / progressFraction) - this.numInstancesRead);
- }
- return -1;
+ InputStream fileStream = new FileInputStream(this.arffFileOption.getFile());
+ this.fileProgressMonitor = new InputStreamProgressMonitor(
+ fileStream);
+ this.fileReader = new BufferedReader(new InputStreamReader(
+ this.fileProgressMonitor));
+ this.instances = new Instances(this.fileReader, 1, this.classIndexOption.getValue());
+ if (this.classIndexOption.getValue() < 0) {
+ this.instances.setClassIndex(this.instances.numAttributes() - 1);
+ } else if (this.classIndexOption.getValue() > 0) {
+ this.instances.setClassIndex(this.classIndexOption.getValue() - 1);
}
+ this.numInstancesRead = 0;
+ this.lastInstanceRead = null;
+ }
- @Override
- public boolean hasMoreInstances() {
- return !this.hitEndOfFile;
- }
-
- @Override
- public InstanceExample nextInstance() {
- if (this.lastInstanceRead == null) {
- readNextInstanceFromFile();
- }
- InstanceExample prevInstance = this.lastInstanceRead;
- this.hitEndOfFile = !readNextInstanceFromFile();
- return prevInstance;
- }
-
- @Override
- public boolean isRestartable() {
+ private boolean getNextInstanceFromFile() throws RuntimeException {
+ try {
+ if (this.instances.readInstance(this.fileReader)) {
+ this.lastInstanceRead = new InstanceExample(this.instances.instance(0));
+ this.instances.delete(); // keep instances clean
+ this.numInstancesRead++;
return true;
+ }
+ if (this.fileReader != null) {
+ this.fileReader.close();
+ this.fileReader = null;
+ }
+ return false;
+ } catch (IOException ioe) {
+ throw new RuntimeException(
+ "ArffFileStream failed to read instance from stream.", ioe);
}
-
- @Override
- public void restart() {
- try {
- reset();
- //this.hitEndOfFile = !readNextInstanceFromFile();
- } catch (IOException ioe) {
- throw new RuntimeException("ArffFileStream restart failed.", ioe);
- }
- }
-
- protected boolean readNextInstanceFromFile() {
- boolean ret;
- if (!this.hasStarted){
- try {
- reset();
- ret = getNextInstanceFromFile();
- this.hitEndOfFile = !ret;
- } catch (IOException ioe) {
- throw new RuntimeException("ArffFileStream restart failed.", ioe);
- }
- this.hasStarted = true;
- } else {
- ret = getNextInstanceFromFile();
- }
- return ret;
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
-
- private void reset() throws IOException {
- if (this.fileReader != null) {
- this.fileReader.close();
- }
- InputStream fileStream = new FileInputStream(this.arffFileOption.getFile());
- this.fileProgressMonitor = new InputStreamProgressMonitor(
- fileStream);
- this.fileReader = new BufferedReader(new InputStreamReader(
- this.fileProgressMonitor));
- this.instances = new Instances(this.fileReader, 1, this.classIndexOption.getValue());
- if (this.classIndexOption.getValue() < 0) {
- this.instances.setClassIndex(this.instances.numAttributes() - 1);
- } else if (this.classIndexOption.getValue() > 0) {
- this.instances.setClassIndex(this.classIndexOption.getValue() - 1);
- }
- this.numInstancesRead = 0;
- this.lastInstanceRead = null;
- }
-
- private boolean getNextInstanceFromFile() throws RuntimeException {
- try {
- if (this.instances.readInstance(this.fileReader)) {
- this.lastInstanceRead = new InstanceExample(this.instances.instance(0));
- this.instances.delete(); // keep instances clean
- this.numInstancesRead++;
- return true;
- }
- if (this.fileReader != null) {
- this.fileReader.close();
- this.fileReader = null;
- }
- return false;
- } catch (IOException ioe) {
- throw new RuntimeException(
- "ArffFileStream failed to read instance from stream.", ioe);
- }
- }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ExampleStream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ExampleStream.java
index dbdeba3..471d83c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ExampleStream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/ExampleStream.java
@@ -25,55 +25,54 @@
import com.yahoo.labs.samoa.moa.core.Example;
/**
- * Interface representing a data stream of examples.
- *
+ * Interface representing a data stream of examples.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface ExampleStream<E extends Example> extends MOAObject {
- /**
- * Gets the header of this stream.
- * This is useful to know attributes and classes.
- * InstancesHeader is an extension of weka.Instances.
- *
- * @return the header of this stream
- */
- public InstancesHeader getHeader();
+ /**
+ * Gets the header of this stream. This is useful to know attributes and
+ * classes. InstancesHeader is an extension of weka.Instances.
+ *
+ * @return the header of this stream
+ */
+ public InstancesHeader getHeader();
- /**
- * Gets the estimated number of remaining instances in this stream
- *
- * @return the estimated number of instances to get from this stream
- */
- public long estimatedRemainingInstances();
+ /**
+ * Gets the estimated number of remaining instances in this stream
+ *
+ * @return the estimated number of instances to get from this stream
+ */
+ public long estimatedRemainingInstances();
- /**
- * Gets whether this stream has more instances to output.
- * This is useful when reading streams from files.
- *
- * @return true if this stream has more instances to output
- */
- public boolean hasMoreInstances();
+ /**
+ * Gets whether this stream has more instances to output. This is useful when
+ * reading streams from files.
+ *
+ * @return true if this stream has more instances to output
+ */
+ public boolean hasMoreInstances();
- /**
- * Gets the next example from this stream.
- *
- * @return the next example of this stream
- */
- public E nextInstance();
+ /**
+ * Gets the next example from this stream.
+ *
+ * @return the next example of this stream
+ */
+ public E nextInstance();
- /**
- * Gets whether this stream can restart.
- *
- * @return true if this stream can restart
- */
- public boolean isRestartable();
+ /**
+ * Gets whether this stream can restart.
+ *
+ * @return true if this stream can restart
+ */
+ public boolean isRestartable();
- /**
- * Restarts this stream. It must be similar to
- * starting a new stream from scratch.
- *
- */
- public void restart();
+ /**
+ * Restarts this stream. It must be similar to starting a new stream from
+ * scratch.
+ *
+ */
+ public void restart();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/InstanceStream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/InstanceStream.java
index 1166e81..30289bd 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/InstanceStream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/InstanceStream.java
@@ -24,14 +24,11 @@
import com.yahoo.labs.samoa.instances.Instance;
/**
- * Interface representing a data stream of instances.
- *
+ * Interface representing a data stream of instances.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface InstanceStream extends ExampleStream<Example<Instance>> {
-
-
-
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEvent.java
index 7062d99..1bada9d 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEvent.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEvent.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.streams.clustering;
/*
@@ -36,15 +35,15 @@
this.timestamp = timestamp;
}
- public String getMessage(){
- return message;
+ public String getMessage() {
+ return message;
}
- public long getTimestamp(){
- return timestamp;
+ public long getTimestamp() {
+ return timestamp;
}
- public String getType(){
- return type;
+ public String getType() {
+ return type;
}
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEventListener.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEventListener.java
index ff05afd..07d9aca 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEventListener.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusterEventListener.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.streams.clustering;
/*
@@ -28,4 +27,3 @@
public void changeCluster(ClusterEvent e);
}
-
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusteringStream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusteringStream.java
index 6e38348..3db9ec6 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusteringStream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/ClusteringStream.java
@@ -1,4 +1,3 @@
-
package com.yahoo.labs.samoa.moa.streams.clustering;
/*
@@ -26,30 +25,29 @@
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.moa.streams.InstanceStream;
-public abstract class ClusteringStream extends AbstractOptionHandler implements InstanceStream{
- public IntOption decayHorizonOption = new IntOption("decayHorizon", 'h',
- "Decay horizon", 1000, 0, Integer.MAX_VALUE);
+public abstract class ClusteringStream extends AbstractOptionHandler implements InstanceStream {
+ public IntOption decayHorizonOption = new IntOption("decayHorizon", 'h',
+ "Decay horizon", 1000, 0, Integer.MAX_VALUE);
- public FloatOption decayThresholdOption = new FloatOption("decayThreshold", 't',
- "Decay horizon threshold", 0.01, 0, 1);
+ public FloatOption decayThresholdOption = new FloatOption("decayThreshold", 't',
+ "Decay horizon threshold", 0.01, 0, 1);
- public IntOption evaluationFrequencyOption = new IntOption("evaluationFrequency", 'e',
- "Evaluation frequency", 1000, 0, Integer.MAX_VALUE);
+ public IntOption evaluationFrequencyOption = new IntOption("evaluationFrequency", 'e',
+ "Evaluation frequency", 1000, 0, Integer.MAX_VALUE);
- public IntOption numAttsOption = new IntOption("numAtts", 'a',
- "The number of attributes to generate.", 2, 0, Integer.MAX_VALUE);
+ public IntOption numAttsOption = new IntOption("numAtts", 'a',
+ "The number of attributes to generate.", 2, 0, Integer.MAX_VALUE);
- public int getDecayHorizon(){
- return decayHorizonOption.getValue();
- }
+ public int getDecayHorizon() {
+ return decayHorizonOption.getValue();
+ }
- public double getDecayThreshold(){
- return decayThresholdOption.getValue();
- }
+ public double getDecayThreshold() {
+ return decayThresholdOption.getValue();
+ }
- public int getEvaluationFrequency(){
- return evaluationFrequencyOption.getValue();
- }
-
+ public int getEvaluationFrequency() {
+ return evaluationFrequencyOption.getValue();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/RandomRBFGeneratorEvents.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/RandomRBFGeneratorEvents.java
index a65d9a1..32a7f72 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/RandomRBFGeneratorEvents.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/clustering/RandomRBFGeneratorEvents.java
@@ -46,879 +46,875 @@
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
-
public class RandomRBFGeneratorEvents extends ClusteringStream {
- private transient Vector listeners;
+ private transient Vector listeners;
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public IntOption modelRandomSeedOption = new IntOption("modelRandomSeed",
- 'm', "Seed for random generation of model.", 1);
+ public IntOption modelRandomSeedOption = new IntOption("modelRandomSeed",
+ 'm', "Seed for random generation of model.", 1);
- public IntOption instanceRandomSeedOption = new IntOption(
- "instanceRandomSeed", 'i',
- "Seed for random generation of instances.", 5);
+ public IntOption instanceRandomSeedOption = new IntOption(
+ "instanceRandomSeed", 'i',
+ "Seed for random generation of instances.", 5);
- public IntOption numClusterOption = new IntOption("numCluster", 'K',
- "The average number of centroids in the model.", 5, 1, Integer.MAX_VALUE);
+ public IntOption numClusterOption = new IntOption("numCluster", 'K',
+ "The average number of centroids in the model.", 5, 1, Integer.MAX_VALUE);
- public IntOption numClusterRangeOption = new IntOption("numClusterRange", 'k',
- "Deviation of the number of centroids in the model.", 3, 0, Integer.MAX_VALUE);
+ public IntOption numClusterRangeOption = new IntOption("numClusterRange", 'k',
+ "Deviation of the number of centroids in the model.", 3, 0, Integer.MAX_VALUE);
- public FloatOption kernelRadiiOption = new FloatOption("kernelRadius", 'R',
- "The average radii of the centroids in the model.", 0.07, 0, 1);
+ public FloatOption kernelRadiiOption = new FloatOption("kernelRadius", 'R',
+ "The average radii of the centroids in the model.", 0.07, 0, 1);
- public FloatOption kernelRadiiRangeOption = new FloatOption("kernelRadiusRange", 'r',
- "Deviation of average radii of the centroids in the model.", 0, 0, 1);
+ public FloatOption kernelRadiiRangeOption = new FloatOption("kernelRadiusRange", 'r',
+ "Deviation of average radii of the centroids in the model.", 0, 0, 1);
- public FloatOption densityRangeOption = new FloatOption("densityRange", 'd',
- "Offset of the average weight a cluster has. Value of 0 means all cluster " +
- "contain the same amount of points.", 0, 0, 1);
+ public FloatOption densityRangeOption = new FloatOption("densityRange", 'd',
+ "Offset of the average weight a cluster has. Value of 0 means all cluster " +
+ "contain the same amount of points.", 0, 0, 1);
- public IntOption speedOption = new IntOption("speed", 'V',
- "Kernels move a predefined distance of 0.01 every X points", 500, 1, Integer.MAX_VALUE);
+ public IntOption speedOption = new IntOption("speed", 'V',
+ "Kernels move a predefined distance of 0.01 every X points", 500, 1, Integer.MAX_VALUE);
- public IntOption speedRangeOption = new IntOption("speedRange", 'v',
- "Speed/Velocity point offset", 0, 0, Integer.MAX_VALUE);
+ public IntOption speedRangeOption = new IntOption("speedRange", 'v',
+ "Speed/Velocity point offset", 0, 0, Integer.MAX_VALUE);
- public FloatOption noiseLevelOption = new FloatOption("noiseLevel", 'N',
- "Noise level", 0.1, 0, 1);
+ public FloatOption noiseLevelOption = new FloatOption("noiseLevel", 'N',
+ "Noise level", 0.1, 0, 1);
- public FlagOption noiseInClusterOption = new FlagOption("noiseInCluster", 'n',
- "Allow noise to be placed within a cluster");
+ public FlagOption noiseInClusterOption = new FlagOption("noiseInCluster", 'n',
+ "Allow noise to be placed within a cluster");
- public IntOption eventFrequencyOption = new IntOption("eventFrequency", 'E',
- "Event frequency. Enable at least one of the events below and set numClusterRange!", 30000, 0, Integer.MAX_VALUE);
+ public IntOption eventFrequencyOption = new IntOption("eventFrequency", 'E',
+ "Event frequency. Enable at least one of the events below and set numClusterRange!", 30000, 0, Integer.MAX_VALUE);
- public FlagOption eventMergeSplitOption = new FlagOption("eventMergeSplitOption", 'M',
- "Enable merging and splitting of clusters. Set eventFrequency and numClusterRange!");
+ public FlagOption eventMergeSplitOption = new FlagOption("eventMergeSplitOption", 'M',
+ "Enable merging and splitting of clusters. Set eventFrequency and numClusterRange!");
- public FlagOption eventDeleteCreateOption = new FlagOption("eventDeleteCreate", 'C',
- "Enable emering and disapperaing of clusters. Set eventFrequency and numClusterRange!");
+ public FlagOption eventDeleteCreateOption = new FlagOption("eventDeleteCreate", 'C',
+ "Enable emering and disapperaing of clusters. Set eventFrequency and numClusterRange!");
-
- private double merge_threshold = 0.7;
- private int kernelMovePointFrequency = 10;
- private double maxDistanceMoveThresholdByStep = 0.01;
- private int maxOverlapFitRuns = 50;
- private double eventFrequencyRange = 0;
+ private double merge_threshold = 0.7;
+ private int kernelMovePointFrequency = 10;
+ private double maxDistanceMoveThresholdByStep = 0.01;
+ private int maxOverlapFitRuns = 50;
+ private double eventFrequencyRange = 0;
- private boolean debug = false;
+ private boolean debug = false;
- private AutoExpandVector<GeneratorCluster> kernels;
- protected Random instanceRandom;
- protected InstancesHeader streamHeader;
- private int numGeneratedInstances;
- private int numActiveKernels;
- private int nextEventCounter;
- private int nextEventChoice = -1;
- private int clusterIdCounter;
- private GeneratorCluster mergeClusterA;
- private GeneratorCluster mergeClusterB;
- private boolean mergeKernelsOverlapping = false;
+ private AutoExpandVector<GeneratorCluster> kernels;
+ protected Random instanceRandom;
+ protected InstancesHeader streamHeader;
+ private int numGeneratedInstances;
+ private int numActiveKernels;
+ private int nextEventCounter;
+ private int nextEventChoice = -1;
+ private int clusterIdCounter;
+ private GeneratorCluster mergeClusterA;
+ private GeneratorCluster mergeClusterB;
+ private boolean mergeKernelsOverlapping = false;
+ private class GeneratorCluster implements Serializable {
+ // TODO: points is redundant to microclusterpoints, we need to come
+ // up with a good strategy that microclusters get updated and
+ // rebuild if needed. Idea: Sort microclusterpoints by timestamp and let
+ // microclusterdecay hold the timestamp for when the last point in a
+ // microcluster gets kicked, then we rebuild... or maybe not... could be
+ // same as searching for point to be kicked. more likely is we rebuild
+ // fewer times then insert.
+ private static final long serialVersionUID = -6301649898961112942L;
- private class GeneratorCluster implements Serializable{
- //TODO: points is redundant to microclusterpoints, we need to come
- //up with a good strategy that microclusters get updated and
- //rebuild if needed. Idea: Sort microclusterpoints by timestamp and let
- // microclusterdecay hold the timestamp for when the last point in a
- //microcluster gets kicked, then we rebuild... or maybe not... could be
- //same as searching for point to be kicked. more likely is we rebuild
- //fewer times then insert.
-
- private static final long serialVersionUID = -6301649898961112942L;
-
- SphereCluster generator;
- int kill = -1;
- boolean merging = false;
- double[] moveVector;
- int totalMovementSteps;
- int currentMovementSteps;
- boolean isSplitting = false;
+ SphereCluster generator;
+ int kill = -1;
+ boolean merging = false;
+ double[] moveVector;
+ int totalMovementSteps;
+ int currentMovementSteps;
+ boolean isSplitting = false;
- LinkedList<DataPoint> points = new LinkedList<DataPoint>();
- ArrayList<SphereCluster> microClusters = new ArrayList<SphereCluster>();
- ArrayList<ArrayList<DataPoint>> microClustersPoints = new ArrayList();
- ArrayList<Integer> microClustersDecay = new ArrayList();
+ LinkedList<DataPoint> points = new LinkedList<DataPoint>();
+ ArrayList<SphereCluster> microClusters = new ArrayList<SphereCluster>();
+ ArrayList<ArrayList<DataPoint>> microClustersPoints = new ArrayList();
+ ArrayList<Integer> microClustersDecay = new ArrayList();
+ public GeneratorCluster(int label) {
+ boolean outofbounds = true;
+ int tryCounter = 0;
+ while (outofbounds && tryCounter < maxOverlapFitRuns) {
+ tryCounter++;
+ outofbounds = false;
+ double[] center = new double[numAttsOption.getValue()];
+ double radius = kernelRadiiOption.getValue() + (instanceRandom.nextBoolean() ? -1 : 1)
+ * kernelRadiiRangeOption.getValue() * instanceRandom.nextDouble();
+ while (radius <= 0) {
+ radius = kernelRadiiOption.getValue() + (instanceRandom.nextBoolean() ? -1 : 1)
+ * kernelRadiiRangeOption.getValue() * instanceRandom.nextDouble();
+ }
+ for (int j = 0; j < numAttsOption.getValue(); j++) {
+ center[j] = instanceRandom.nextDouble();
+ if (center[j] - radius < 0 || center[j] + radius > 1) {
+ outofbounds = true;
+ break;
+ }
+ }
+ generator = new SphereCluster(center, radius);
+ }
+ if (tryCounter < maxOverlapFitRuns) {
+ generator.setId(label);
+ double avgWeight = 1.0 / numClusterOption.getValue();
+ double weight = avgWeight + (instanceRandom.nextBoolean() ? -1 : 1) * avgWeight * densityRangeOption.getValue()
+ * instanceRandom.nextDouble();
+ generator.setWeight(weight);
+ setDesitnation(null);
+ }
+ else {
+ generator = null;
+ kill = 0;
+ System.out.println("Tried " + maxOverlapFitRuns + " times to create kernel. Reduce average radii.");
+ }
+ }
- public GeneratorCluster(int label) {
- boolean outofbounds = true;
- int tryCounter = 0;
- while(outofbounds && tryCounter < maxOverlapFitRuns){
- tryCounter++;
- outofbounds = false;
- double[] center = new double [numAttsOption.getValue()];
- double radius = kernelRadiiOption.getValue()+(instanceRandom.nextBoolean()?-1:1)*kernelRadiiRangeOption.getValue()*instanceRandom.nextDouble();
- while(radius <= 0){
- radius = kernelRadiiOption.getValue()+(instanceRandom.nextBoolean()?-1:1)*kernelRadiiRangeOption.getValue()*instanceRandom.nextDouble();
- }
- for (int j = 0; j < numAttsOption.getValue(); j++) {
- center[j] = instanceRandom.nextDouble();
- if(center[j]- radius < 0 || center[j] + radius > 1){
- outofbounds = true;
- break;
- }
- }
- generator = new SphereCluster(center, radius);
- }
- if(tryCounter < maxOverlapFitRuns){
- generator.setId(label);
- double avgWeight = 1.0/numClusterOption.getValue();
- double weight = avgWeight + (instanceRandom.nextBoolean()?-1:1)*avgWeight*densityRangeOption.getValue()*instanceRandom.nextDouble();
- generator.setWeight(weight);
+ public GeneratorCluster(int label, SphereCluster cluster) {
+ this.generator = cluster;
+ cluster.setId(label);
+ setDesitnation(null);
+ }
+
+ public int getWorkID() {
+ for (int c = 0; c < kernels.size(); c++) {
+ if (kernels.get(c).equals(this))
+ return c;
+ }
+ return -1;
+ }
+
+ private void updateKernel() {
+ if (kill == 0) {
+ kernels.remove(this);
+ }
+ if (kill > 0) {
+ kill--;
+ }
+ // we could be lot more precise if we would keep track of timestamps of
+ // points
+ // then we could remove all old points and rebuild the cluster on up to
+ // date point base
+ // BUT worse the effort??? so far we just want to avoid overlap with this,
+ // so its more
+ // konservative as needed. Only needs to change when we need a thighter
+ // representation
+ for (int m = 0; m < microClusters.size(); m++) {
+ if (numGeneratedInstances - microClustersDecay.get(m) > decayHorizonOption.getValue()) {
+ microClusters.remove(m);
+ microClustersPoints.remove(m);
+ microClustersDecay.remove(m);
+ }
+ }
+
+ if (!points.isEmpty()
+ && numGeneratedInstances - points.getFirst().getTimestamp() >= decayHorizonOption.getValue()) {
+ // if(debug)
+ // System.out.println("Cleaning up macro cluster "+generator.getId());
+ points.removeFirst();
+ }
+
+ }
+
+ private void addInstance(Instance instance) {
+ DataPoint point = new DataPoint(instance, numGeneratedInstances);
+ points.add(point);
+
+ int minMicroIndex = -1;
+ double minHullDist = Double.MAX_VALUE;
+ boolean inserted = false;
+ // we favour more recently built clusters so we can remove earlier clusters
+ // sooner
+ for (int m = microClusters.size() - 1; m >= 0; m--) {
+ SphereCluster micro = microClusters.get(m);
+ double hulldist = micro.getCenterDistance(point) - micro.getRadius();
+ // point fits into existing cluster
+ if (hulldist <= 0) {
+ microClustersPoints.get(m).add(point);
+ microClustersDecay.set(m, numGeneratedInstances);
+ inserted = true;
+ break;
+ }
+ // if not, check if its at least the closest cluster
+ else {
+ if (hulldist < minHullDist) {
+ minMicroIndex = m;
+ minHullDist = hulldist;
+ }
+ }
+ }
+ // Resetting index choice for alternative cluster building
+ int alt = 1;
+ if (alt == 1)
+ minMicroIndex = -1;
+ if (!inserted) {
+ // add to closest cluster and expand cluster
+ if (minMicroIndex != -1) {
+ microClustersPoints.get(minMicroIndex).add(point);
+ // we should keep the miniball instances and just check in
+ // new points instead of rebuilding the whole thing
+ SphereCluster s = new SphereCluster(microClustersPoints.get(minMicroIndex), numAttsOption.getValue());
+ // check if current microcluster is bigger than the generating cluster
+ if (s.getRadius() > generator.getRadius()) {
+ // remove previously added point
+ microClustersPoints.get(minMicroIndex).remove(microClustersPoints.get(minMicroIndex).size() - 1);
+ minMicroIndex = -1;
+ }
+ else {
+ microClusters.set(minMicroIndex, s);
+ microClustersDecay.set(minMicroIndex, numGeneratedInstances);
+ }
+ }
+ // minMicroIndex might have been reset above
+ // create new micro cluster
+ if (minMicroIndex == -1) {
+ ArrayList<DataPoint> microPoints = new ArrayList<DataPoint>();
+ microPoints.add(point);
+ SphereCluster s;
+ if (alt == 0)
+ s = new SphereCluster(microPoints, numAttsOption.getValue());
+ else
+ s = new SphereCluster(generator.getCenter(), generator.getRadius(), 1);
+
+ microClusters.add(s);
+ microClustersPoints.add(microPoints);
+ microClustersDecay.add(numGeneratedInstances);
+ int id = 0;
+ while (id < kernels.size()) {
+ if (kernels.get(id) == this)
+ break;
+ id++;
+ }
+ s.setGroundTruth(id);
+ }
+ }
+
+ }
+
+ private void move() {
+ if (currentMovementSteps < totalMovementSteps) {
+ currentMovementSteps++;
+ if (moveVector == null) {
+ return;
+ }
+ else {
+ double[] center = generator.getCenter();
+ boolean outofbounds = true;
+ while (outofbounds) {
+ double radius = generator.getRadius();
+ outofbounds = false;
+ center = generator.getCenter();
+ for (int d = 0; d < center.length; d++) {
+ center[d] += moveVector[d];
+ if (center[d] - radius < 0 || center[d] + radius > 1) {
+ outofbounds = true;
setDesitnation(null);
+ break;
+ }
}
- else{
- generator = null;
- kill = 0;
- System.out.println("Tried "+maxOverlapFitRuns+" times to create kernel. Reduce average radii." );
- }
+ }
+ generator.setCenter(center);
}
-
- public GeneratorCluster(int label, SphereCluster cluster) {
- this.generator = cluster;
- cluster.setId(label);
- setDesitnation(null);
+ }
+ else {
+ if (!merging) {
+ setDesitnation(null);
+ isSplitting = false;
}
-
- public int getWorkID(){
- for(int c = 0; c < kernels.size(); c++){
- if(kernels.get(c).equals(this))
- return c;
- }
- return -1;
- }
-
- private void updateKernel(){
- if(kill == 0){
- kernels.remove(this);
- }
- if(kill > 0){
- kill--;
- }
- //we could be lot more precise if we would keep track of timestamps of points
- //then we could remove all old points and rebuild the cluster on up to date point base
- //BUT worse the effort??? so far we just want to avoid overlap with this, so its more
- //konservative as needed. Only needs to change when we need a thighter representation
- for (int m = 0; m < microClusters.size(); m++) {
- if(numGeneratedInstances-microClustersDecay.get(m) > decayHorizonOption.getValue()){
- microClusters.remove(m);
- microClustersPoints.remove(m);
- microClustersDecay.remove(m);
- }
- }
-
- if(!points.isEmpty() && numGeneratedInstances-points.getFirst().getTimestamp() >= decayHorizonOption.getValue()){
-// if(debug)
-// System.out.println("Cleaning up macro cluster "+generator.getId());
- points.removeFirst();
- }
-
- }
-
- private void addInstance(Instance instance){
- DataPoint point = new DataPoint(instance, numGeneratedInstances);
- points.add(point);
-
- int minMicroIndex = -1;
- double minHullDist = Double.MAX_VALUE;
- boolean inserted = false;
- //we favour more recently build clusters so we can remove earlier cluster sooner
- for (int m = microClusters.size()-1; m >=0 ; m--) {
- SphereCluster micro = microClusters.get(m);
- double hulldist = micro.getCenterDistance(point)-micro.getRadius();
- //point fits into existing cluster
- if(hulldist <= 0){
- microClustersPoints.get(m).add(point);
- microClustersDecay.set(m, numGeneratedInstances);
- inserted = true;
- break;
- }
- //if not, check if its at least the closest cluster
- else{
- if(hulldist < minHullDist){
- minMicroIndex = m;
- minHullDist = hulldist;
- }
- }
- }
- //Reseting index choice for alternative cluster building
- int alt = 1;
- if(alt == 1)
- minMicroIndex = -1;
- if(!inserted){
- //add to closest cluster and expand cluster
- if(minMicroIndex!=-1){
- microClustersPoints.get(minMicroIndex).add(point);
- //we should keep the miniball instances and just check in
- //new points instead of rebuilding the whole thing
- SphereCluster s = new SphereCluster(microClustersPoints.get(minMicroIndex),numAttsOption.getValue());
- //check if current microcluster is bigger then generating cluster
- if(s.getRadius() > generator.getRadius()){
- //remove previously added point
- microClustersPoints.get(minMicroIndex).remove(microClustersPoints.get(minMicroIndex).size()-1);
- minMicroIndex = -1;
- }
- else{
- microClusters.set(minMicroIndex, s);
- microClustersDecay.set(minMicroIndex, numGeneratedInstances);
- }
- }
- //minMicroIndex might have been reset above
- //create new micro cluster
- if(minMicroIndex == -1){
- ArrayList<DataPoint> microPoints = new ArrayList<DataPoint>();
- microPoints.add(point);
- SphereCluster s;
- if(alt == 0)
- s = new SphereCluster(microPoints,numAttsOption.getValue());
- else
- s = new SphereCluster(generator.getCenter(),generator.getRadius(),1);
-
- microClusters.add(s);
- microClustersPoints.add(microPoints);
- microClustersDecay.add(numGeneratedInstances);
- int id = 0;
- while(id < kernels.size()){
- if(kernels.get(id) == this)
- break;
- id++;
- }
- s.setGroundTruth(id);
- }
- }
-
- }
-
-
- private void move(){
- if(currentMovementSteps < totalMovementSteps){
- currentMovementSteps++;
- if( moveVector == null){
- return;
- }
- else{
- double[] center = generator.getCenter();
- boolean outofbounds = true;
- while(outofbounds){
- double radius = generator.getRadius();
- outofbounds = false;
- center = generator.getCenter();
- for ( int d = 0; d < center.length; d++ ) {
- center[d]+= moveVector[d];
- if(center[d]- radius < 0 || center[d] + radius > 1){
- outofbounds = true;
- setDesitnation(null);
- break;
- }
- }
- }
- generator.setCenter(center);
- }
- }
- else{
- if(!merging){
- setDesitnation(null);
- isSplitting = false;
- }
- }
- }
-
- void setDesitnation(double[] destination){
-
- if(destination == null){
- destination = new double [numAttsOption.getValue()];
- for (int j = 0; j < numAttsOption.getValue(); j++) {
- destination[j] = instanceRandom.nextDouble();
- }
- }
- double[] center = generator.getCenter();
- int dim = center.length;
-
- double[] v = new double[dim];
-
- for ( int d = 0; d < dim; d++ ) {
- v[d]=destination[d]-center[d];
- }
- setMoveVector(v);
- }
-
- void setMoveVector(double[] vector){
- //we are ignoring the steps, otherwise we have to change
- //speed of the kernels, do we want that?
- moveVector = vector;
- int speedInPoints = speedOption.getValue();
- if(speedRangeOption.getValue() > 0)
- speedInPoints +=(instanceRandom.nextBoolean()?-1:1)*instanceRandom.nextInt(speedRangeOption.getValue());
- if(speedInPoints < 1) speedInPoints = speedOption.getValue();
-
-
- double length = 0;
- for ( int d = 0; d < moveVector.length; d++ ) {
- length+=Math.pow(vector[d],2);
- }
- length = Math.sqrt(length);
-
- totalMovementSteps = (int)(length/(maxDistanceMoveThresholdByStep*kernelMovePointFrequency)*speedInPoints);
- for ( int d = 0; d < moveVector.length; d++ ) {
- moveVector[d]/=(double)totalMovementSteps;
- }
-
-
- currentMovementSteps = 0;
-// if(debug){
-// System.out.println("Setting new direction for C"+generator.getId()+": distance "
-// +length+" in "+totalMovementSteps+" steps");
-// }
- }
-
- private String tryMerging(GeneratorCluster merge){
- String message = "";
- double overlapDegree = generator.overlapRadiusDegree(merge.generator);
- if(overlapDegree > merge_threshold){
- SphereCluster mcluster = merge.generator;
- double radius = Math.max(generator.getRadius(), mcluster.getRadius());
- generator.combine(mcluster);
-
-// //adjust radius, get bigger and bigger with high dim data
- generator.setRadius(radius);
-// double[] center = generator.getCenter();
-// double[] mcenter = mcluster.getCenter();
-// double weight = generator.getWeight();
-// double mweight = generator.getWeight();
-//// for (int i = 0; i < center.length; i++) {
-//// center[i] = (center[i] * weight + mcenter[i] * mweight) / (mweight + weight);
-//// }
-// generator.setWeight(weight + mweight);
- message = "Clusters merging: "+mergeClusterB.generator.getId()+" into "+mergeClusterA.generator.getId();
-
- //clean up and restet merging stuff
- //mark kernel so it gets killed when it doesn't contain any more instances
- merge.kill = decayHorizonOption.getValue();
- //set weight to 0 so no new instances will be created in the cluster
- mcluster.setWeight(0.0);
- normalizeWeights();
- numActiveKernels--;
- mergeClusterB = mergeClusterA = null;
- merging = false;
- mergeKernelsOverlapping = false;
- }
- else{
- if(overlapDegree > 0 && !mergeKernelsOverlapping){
- mergeKernelsOverlapping = true;
- message = "Merge overlapping started";
- }
- }
- return message;
- }
-
- private String splitKernel(){
- isSplitting = true;
- //todo radius range
- double radius = kernelRadiiOption.getValue();
- double avgWeight = 1.0/numClusterOption.getValue();
- double weight = avgWeight + avgWeight*densityRangeOption.getValue()*instanceRandom.nextDouble();
- SphereCluster spcluster = null;
-
- double[] center = generator.getCenter();
- spcluster = new SphereCluster(center, radius, weight);
-
- if(spcluster !=null){
- GeneratorCluster gc = new GeneratorCluster(clusterIdCounter++, spcluster);
- gc.isSplitting = true;
- kernels.add(gc);
- normalizeWeights();
- numActiveKernels++;
- return "Split from Kernel "+generator.getId();
- }
- else{
- System.out.println("Tried to split new kernel from C"+generator.getId()+
- ". Not enough room for new cluster, decrease average radii, number of clusters or enable overlap.");
- return "";
- }
- }
-
- private String fadeOut(){
- kill = decayHorizonOption.getValue();
- generator.setWeight(0.0);
- numActiveKernels--;
- normalizeWeights();
- return "Fading out C"+generator.getId();
- }
-
-
+ }
}
- public RandomRBFGeneratorEvents() {
- noiseInClusterOption.set();
-// eventDeleteCreateOption.set();
-// eventMergeSplitOption.set();
- }
+ void setDesitnation(double[] destination) {
- public InstancesHeader getHeader() {
- return streamHeader;
- }
-
- public long estimatedRemainingInstances() {
- return -1;
- }
-
- public boolean hasMoreInstances() {
- return true;
- }
-
- public boolean isRestartable() {
- return true;
- }
-
- @Override
- public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- monitor.setCurrentActivity("Preparing random RBF...", -1.0);
- generateHeader();
- restart();
- }
-
- public void restart() {
- instanceRandom = new Random(instanceRandomSeedOption.getValue());
- nextEventCounter = eventFrequencyOption.getValue();
- nextEventChoice = getNextEvent();
- numActiveKernels = 0;
- numGeneratedInstances = 0;
- clusterIdCounter = 0;
- mergeClusterA = mergeClusterB = null;
- kernels = new AutoExpandVector<GeneratorCluster>();
-
- initKernels();
- }
-
- protected void generateHeader() { // 2013/06/02: Noise label
- ArrayList<Attribute> attributes = new ArrayList<Attribute>();
- for (int i = 0; i < this.numAttsOption.getValue(); i++) {
- attributes.add(new Attribute("att" + (i + 1)));
+ if (destination == null) {
+ destination = new double[numAttsOption.getValue()];
+ for (int j = 0; j < numAttsOption.getValue(); j++) {
+ destination[j] = instanceRandom.nextDouble();
}
-
- ArrayList<String> classLabels = new ArrayList<String>();
- for (int i = 0; i < this.numClusterOption.getValue(); i++) {
- classLabels.add("class" + (i + 1));
- }
- if (noiseLevelOption.getValue() > 0) classLabels.add("noise"); // The last label = "noise"
-
- attributes.add(new Attribute("class", classLabels));
- streamHeader = new InstancesHeader(new Instances(getCLICreationString(InstanceStream.class), attributes, 0));
- streamHeader.setClassIndex(streamHeader.numAttributes() - 1);
+ }
+ double[] center = generator.getCenter();
+ int dim = center.length;
+
+ double[] v = new double[dim];
+
+ for (int d = 0; d < dim; d++) {
+ v[d] = destination[d] - center[d];
+ }
+ setMoveVector(v);
}
-
- protected void initKernels() {
- for (int i = 0; i < numClusterOption.getValue(); i++) {
- kernels.add(new GeneratorCluster(clusterIdCounter));
- numActiveKernels++;
- clusterIdCounter++;
- }
+ void setMoveVector(double[] vector) {
+ // we are ignoring the steps, otherwise we have to change
+ // speed of the kernels, do we want that?
+ moveVector = vector;
+ int speedInPoints = speedOption.getValue();
+ if (speedRangeOption.getValue() > 0)
+ speedInPoints += (instanceRandom.nextBoolean() ? -1 : 1) * instanceRandom.nextInt(speedRangeOption.getValue());
+ if (speedInPoints < 1)
+ speedInPoints = speedOption.getValue();
+
+ double length = 0;
+ for (int d = 0; d < moveVector.length; d++) {
+ length += Math.pow(vector[d], 2);
+ }
+ length = Math.sqrt(length);
+
+ totalMovementSteps = (int) (length / (maxDistanceMoveThresholdByStep * kernelMovePointFrequency) * speedInPoints);
+ for (int d = 0; d < moveVector.length; d++) {
+ moveVector[d] /= (double) totalMovementSteps;
+ }
+
+ currentMovementSteps = 0;
+ // if(debug){
+ // System.out.println("Setting new direction for C"+generator.getId()+": distance "
+ // +length+" in "+totalMovementSteps+" steps");
+ // }
+ }
+
+ private String tryMerging(GeneratorCluster merge) {
+ String message = "";
+ double overlapDegree = generator.overlapRadiusDegree(merge.generator);
+ if (overlapDegree > merge_threshold) {
+ SphereCluster mcluster = merge.generator;
+ double radius = Math.max(generator.getRadius(), mcluster.getRadius());
+ generator.combine(mcluster);
+
+ // //adjust radius, get bigger and bigger with high dim data
+ generator.setRadius(radius);
+ // double[] center = generator.getCenter();
+ // double[] mcenter = mcluster.getCenter();
+ // double weight = generator.getWeight();
+ // double mweight = generator.getWeight();
+ // // for (int i = 0; i < center.length; i++) {
+ // // center[i] = (center[i] * weight + mcenter[i] * mweight) / (mweight
+ // + weight);
+ // // }
+ // generator.setWeight(weight + mweight);
+ message = "Clusters merging: " + mergeClusterB.generator.getId() + " into " + mergeClusterA.generator.getId();
+
+ // clean up and reset merging stuff
+ // mark kernel so it gets killed when it doesn't contain any more
+ // instances
+ merge.kill = decayHorizonOption.getValue();
+ // set weight to 0 so no new instances will be created in the cluster
+ mcluster.setWeight(0.0);
normalizeWeights();
+ numActiveKernels--;
+ mergeClusterB = mergeClusterA = null;
+ merging = false;
+ mergeKernelsOverlapping = false;
+ }
+ else {
+ if (overlapDegree > 0 && !mergeKernelsOverlapping) {
+ mergeKernelsOverlapping = true;
+ message = "Merge overlapping started";
+ }
+ }
+ return message;
}
- public InstanceExample nextInstance() {
- numGeneratedInstances++;
- eventScheduler();
+ private String splitKernel() {
+ isSplitting = true;
+ // todo radius range
+ double radius = kernelRadiiOption.getValue();
+ double avgWeight = 1.0 / numClusterOption.getValue();
+ double weight = avgWeight + avgWeight * densityRangeOption.getValue() * instanceRandom.nextDouble();
+ SphereCluster spcluster = null;
- //make room for the classlabel
- double[] values_new = new double [numAttsOption.getValue()]; //+1
- double[] values = null;
- int clusterChoice = -1;
+ double[] center = generator.getCenter();
+ spcluster = new SphereCluster(center, radius, weight);
- if(instanceRandom.nextDouble() > noiseLevelOption.getValue()){
- clusterChoice = chooseWeightedElement();
- values = kernels.get(clusterChoice).generator.sample(instanceRandom).toDoubleArray();
- }
- else{
- //get ranodm noise point
- values = getNoisePoint();
- }
-
- if(Double.isNaN(values[0])){
- System.out.println("Instance corrupted:"+numGeneratedInstances);
- }
- System.arraycopy(values, 0, values_new, 0, values.length);
-
- Instance inst = new DenseInstance(1.0, values_new);
- inst.setDataset(getHeader());
- if(clusterChoice == -1){
- // 2013/06/02 (Yunsu Kim)
- // Noise instance has the last class value instead of "-1"
- // Preventing ArrayIndexOutOfBoundsException in WriteStreamToARFFFile
- inst.setClassValue(numClusterOption.getValue());
- }
- else{
- inst.setClassValue(kernels.get(clusterChoice).generator.getId());
- //Do we need micro cluster representation if have overlapping clusters?
- //if(!overlappingOption.isSet())
- kernels.get(clusterChoice).addInstance(inst);
- }
-// System.out.println(numGeneratedInstances+": Overlap is"+updateOverlaps());
-
- return new InstanceExample(inst);
- }
-
-
- public Clustering getGeneratingClusters(){
- Clustering clustering = new Clustering();
- for (int c = 0; c < kernels.size(); c++) {
- clustering.add(kernels.get(c).generator);
- }
- return clustering;
- }
-
- public Clustering getMicroClustering(){
- Clustering clustering = new Clustering();
- int id = 0;
-
- for (int c = 0; c < kernels.size(); c++) {
- for (int m = 0; m < kernels.get(c).microClusters.size(); m++) {
- kernels.get(c).microClusters.get(m).setId(id);
- kernels.get(c).microClusters.get(m).setGroundTruth(kernels.get(c).generator.getId());
- clustering.add(kernels.get(c).microClusters.get(m));
- id++;
- }
- }
-
- //System.out.println("numMicroKernels "+clustering.size());
- return clustering;
- }
-
-/**************************** EVENTS ******************************************/
- private void eventScheduler(){
-
- for ( int i = 0; i < kernels.size(); i++ ) {
- kernels.get(i).updateKernel();
- }
-
- nextEventCounter--;
- //only move kernels every 10 points, performance reasons????
- //should this be randomized as well???
- if(nextEventCounter%kernelMovePointFrequency == 0){
- //move kernels
- for ( int i = 0; i < kernels.size(); i++ ) {
- kernels.get(i).move();
- //overlapControl();
- }
- }
-
-
- if(eventFrequencyOption.getValue() == 0){
- return;
- }
-
- String type ="";
- String message ="";
- boolean eventFinished = false;
- switch(nextEventChoice){
- case 0:
- if(numActiveKernels > 1 && numActiveKernels > numClusterOption.getValue() - numClusterRangeOption.getValue()){
- message = mergeKernels(nextEventCounter);
- type = "Merge";
- }
- if(mergeClusterA==null && mergeClusterB==null && message.startsWith("Clusters merging")){
- eventFinished = true;
- }
- break;
- case 1:
- if(nextEventCounter<=0){
- if(numActiveKernels < numClusterOption.getValue() + numClusterRangeOption.getValue()){
- type = "Split";
- message = splitKernel();
- }
- eventFinished = true;
- }
- break;
- case 2:
- if(nextEventCounter<=0){
- if(numActiveKernels > 1 && numActiveKernels > numClusterOption.getValue() - numClusterRangeOption.getValue()){
- message = fadeOut();
- type = "Delete";
- }
- eventFinished = true;
- }
- break;
- case 3:
- if(nextEventCounter<=0){
- if(numActiveKernels < numClusterOption.getValue() + numClusterRangeOption.getValue()){
- message = fadeIn();
- type = "Create";
- }
- eventFinished = true;
- }
- break;
-
- }
- if (eventFinished){
- nextEventCounter = (int)(eventFrequencyOption.getValue()+(instanceRandom.nextBoolean()?-1:1)*eventFrequencyOption.getValue()*eventFrequencyRange*instanceRandom.nextDouble());
- nextEventChoice = getNextEvent();
- //System.out.println("Next event choice: "+nextEventChoice);
- }
- if(!message.isEmpty()){
- message+=" (numKernels = "+numActiveKernels+" at "+numGeneratedInstances+")";
- if(!type.equals("Merge") || message.startsWith("Clusters merging"))
- fireClusterChange(numGeneratedInstances, type, message);
- }
- }
-
- private int getNextEvent() {
- int choice = -1;
- boolean lowerLimit = numActiveKernels <= numClusterOption.getValue() - numClusterRangeOption.getValue();
- boolean upperLimit = numActiveKernels >= numClusterOption.getValue() + numClusterRangeOption.getValue();
-
- if(!lowerLimit || !upperLimit){
- int mode = -1;
- if(eventDeleteCreateOption.isSet() && eventMergeSplitOption.isSet()){
- mode = instanceRandom.nextInt(2);
- }
-
- if(mode==0 || (mode==-1 && eventMergeSplitOption.isSet())){
- //have we reached a limit? if not free choice
- if(!lowerLimit && !upperLimit)
- choice = instanceRandom.nextInt(2);
- else
- //we have a limit. if lower limit, choose split
- if(lowerLimit)
- choice = 1;
- //otherwise we reached upper level, choose merge
- else
- choice = 0;
- }
-
- if(mode==1 || (mode==-1 && eventDeleteCreateOption.isSet())){
- //have we reached a limit? if not free choice
- if(!lowerLimit && !upperLimit)
- choice = instanceRandom.nextInt(2)+2;
- else
- //we have a limit. if lower limit, choose create
- if(lowerLimit)
- choice = 3;
- //otherwise we reached upper level, choose delete
- else
- choice = 2;
- }
- }
-
-
- return choice;
- }
-
- private String fadeOut(){
- int id = instanceRandom.nextInt(kernels.size());
- while(kernels.get(id).kill!=-1)
- id = instanceRandom.nextInt(kernels.size());
-
- String message = kernels.get(id).fadeOut();
- return message;
- }
-
- private String fadeIn(){
- GeneratorCluster gc = new GeneratorCluster(clusterIdCounter++);
- kernels.add(gc);
- numActiveKernels++;
- normalizeWeights();
- return "Creating new cluster";
- }
-
-
- private String changeWeight(boolean increase){
- double changeRate = 0.1;
- int id = instanceRandom.nextInt(kernels.size());
- while(kernels.get(id).kill!=-1)
- id = instanceRandom.nextInt(kernels.size());
-
- int sign = 1;
- if(!increase)
- sign = -1;
- double weight_old = kernels.get(id).generator.getWeight();
- double delta = sign*numActiveKernels*instanceRandom.nextDouble()*changeRate;
- kernels.get(id).generator.setWeight(weight_old+delta);
-
+ if (spcluster != null) {
+ GeneratorCluster gc = new GeneratorCluster(clusterIdCounter++, spcluster);
+ gc.isSplitting = true;
+ kernels.add(gc);
normalizeWeights();
-
- String message;
- if(increase)
- message = "Increase ";
- else
- message = "Decrease ";
- message+=" weight on Cluster "+id+" from "+weight_old+" to "+(weight_old+delta);
- return message;
-
-
- }
-
- private String changeRadius(boolean increase){
- double maxChangeRate = 0.1;
- int id = instanceRandom.nextInt(kernels.size());
- while(kernels.get(id).kill!=-1)
- id = instanceRandom.nextInt(kernels.size());
-
- int sign = 1;
- if(!increase)
- sign = -1;
-
- double r_old = kernels.get(id).generator.getRadius();
- double r_new =r_old+sign*r_old*instanceRandom.nextDouble()*maxChangeRate;
- if(r_new >= 0.5) return "Radius to big";
- kernels.get(id).generator.setRadius(r_new);
-
- String message;
- if(increase)
- message = "Increase ";
- else
- message = "Decrease ";
- message+=" radius on Cluster "+id+" from "+r_old+" to "+r_new;
- return message;
- }
-
- private String splitKernel(){
- int id = instanceRandom.nextInt(kernels.size());
- while(kernels.get(id).kill!=-1)
- id = instanceRandom.nextInt(kernels.size());
-
- String message = kernels.get(id).splitKernel();
-
- return message;
- }
-
- private String mergeKernels(int steps){
- if(numActiveKernels >1 && ((mergeClusterA == null && mergeClusterB == null))){
-
- //choose clusters to merge
- double diseredDist = steps / speedOption.getValue() * maxDistanceMoveThresholdByStep;
- double minDist = Double.MAX_VALUE;
-// System.out.println("DisredDist:"+(2*diseredDist));
- for(int i = 0; i < kernels.size(); i++){
- for(int j = 0; j < i; j++){
- if(kernels.get(i).kill!=-1 || kernels.get(j).kill!=-1){
- continue;
- }
- else{
- double kernelDist = kernels.get(i).generator.getCenterDistance(kernels.get(j).generator);
- double d = kernelDist-2*diseredDist;
-// System.out.println("Dist:"+i+" / "+j+" "+d);
- if(Math.abs(d) < minDist &&
- (minDist != Double.MAX_VALUE || d>0 || Math.abs(d) < 0.001)){
- minDist = Math.abs(d);
- mergeClusterA = kernels.get(i);
- mergeClusterB = kernels.get(j);
- }
- }
- }
- }
-
- if(mergeClusterA!=null && mergeClusterB!=null){
- double[] merge_point = mergeClusterA.generator.getCenter();
- double[] v = mergeClusterA.generator.getDistanceVector(mergeClusterB.generator);
- for (int i = 0; i < v.length; i++) {
- merge_point[i]= merge_point[i]+v[i]*0.5;
- }
-
- mergeClusterA.merging = true;
- mergeClusterB.merging = true;
- mergeClusterA.setDesitnation(merge_point);
- mergeClusterB.setDesitnation(merge_point);
-
- if(debug){
- System.out.println("Center1"+Arrays.toString(mergeClusterA.generator.getCenter()));
- System.out.println("Center2"+Arrays.toString(mergeClusterB.generator.getCenter()));
- System.out.println("Vector"+Arrays.toString(v));
-
- System.out.println("Try to merge cluster "+mergeClusterA.generator.getId()+
- " into "+mergeClusterB.generator.getId()+
- " at "+Arrays.toString(merge_point)+
- " time "+numGeneratedInstances);
- }
- return "Init merge";
- }
- }
-
- if(mergeClusterA != null && mergeClusterB != null){
-
- //movekernels will move the kernels close to each other,
- //we just need to check and merge here if they are close enough
- return mergeClusterA.tryMerging(mergeClusterB);
- }
-
+ numActiveKernels++;
+ return "Split from Kernel " + generator.getId();
+ }
+ else {
+ System.out.println("Tried to split new kernel from C" + generator.getId() +
+ ". Not enough room for new cluster, decrease average radii, number of clusters or enable overlap.");
return "";
+ }
}
-
-
-
-/************************* TOOLS **************************************/
-
- public void getDescription(StringBuilder sb, int indent) {
-
+ private String fadeOut() {
+ kill = decayHorizonOption.getValue();
+ generator.setWeight(0.0);
+ numActiveKernels--;
+ normalizeWeights();
+ return "Fading out C" + generator.getId();
}
- private double[] getNoisePoint(){
- double [] sample = new double [numAttsOption.getValue()];
- boolean incluster = true;
- int counter = 20;
- while(incluster){
- for (int j = 0; j < numAttsOption.getValue(); j++) {
- sample[j] = instanceRandom.nextDouble();
+ }
+
+ public RandomRBFGeneratorEvents() {
+ noiseInClusterOption.set();
+ // eventDeleteCreateOption.set();
+ // eventMergeSplitOption.set();
+ }
+
+ public InstancesHeader getHeader() {
+ return streamHeader;
+ }
+
+ public long estimatedRemainingInstances() {
+ return -1;
+ }
+
+ public boolean hasMoreInstances() {
+ return true;
+ }
+
+ public boolean isRestartable() {
+ return true;
+ }
+
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ monitor.setCurrentActivity("Preparing random RBF...", -1.0);
+ generateHeader();
+ restart();
+ }
+
+ public void restart() {
+ instanceRandom = new Random(instanceRandomSeedOption.getValue());
+ nextEventCounter = eventFrequencyOption.getValue();
+ nextEventChoice = getNextEvent();
+ numActiveKernels = 0;
+ numGeneratedInstances = 0;
+ clusterIdCounter = 0;
+ mergeClusterA = mergeClusterB = null;
+ kernels = new AutoExpandVector<GeneratorCluster>();
+
+ initKernels();
+ }
+
+ protected void generateHeader() { // 2013/06/02: Noise label
+ ArrayList<Attribute> attributes = new ArrayList<Attribute>();
+ for (int i = 0; i < this.numAttsOption.getValue(); i++) {
+ attributes.add(new Attribute("att" + (i + 1)));
+ }
+
+ ArrayList<String> classLabels = new ArrayList<String>();
+ for (int i = 0; i < this.numClusterOption.getValue(); i++) {
+ classLabels.add("class" + (i + 1));
+ }
+ if (noiseLevelOption.getValue() > 0)
+ classLabels.add("noise"); // The last label = "noise"
+
+ attributes.add(new Attribute("class", classLabels));
+ streamHeader = new InstancesHeader(new Instances(getCLICreationString(InstanceStream.class), attributes, 0));
+ streamHeader.setClassIndex(streamHeader.numAttributes() - 1);
+ }
+
+ protected void initKernels() {
+ for (int i = 0; i < numClusterOption.getValue(); i++) {
+ kernels.add(new GeneratorCluster(clusterIdCounter));
+ numActiveKernels++;
+ clusterIdCounter++;
+ }
+ normalizeWeights();
+ }
+
+ public InstanceExample nextInstance() {
+ numGeneratedInstances++;
+ eventScheduler();
+
+ // make room for the classlabel
+ double[] values_new = new double[numAttsOption.getValue()]; // +1
+ double[] values = null;
+ int clusterChoice = -1;
+
+ if (instanceRandom.nextDouble() > noiseLevelOption.getValue()) {
+ clusterChoice = chooseWeightedElement();
+ values = kernels.get(clusterChoice).generator.sample(instanceRandom).toDoubleArray();
+ }
+ else {
+ // get random noise point
+ values = getNoisePoint();
+ }
+
+ if (Double.isNaN(values[0])) {
+ System.out.println("Instance corrupted:" + numGeneratedInstances);
+ }
+ System.arraycopy(values, 0, values_new, 0, values.length);
+
+ Instance inst = new DenseInstance(1.0, values_new);
+ inst.setDataset(getHeader());
+ if (clusterChoice == -1) {
+ // 2013/06/02 (Yunsu Kim)
+ // Noise instance has the last class value instead of "-1"
+ // Preventing ArrayIndexOutOfBoundsException in WriteStreamToARFFFile
+ inst.setClassValue(numClusterOption.getValue());
+ }
+ else {
+ inst.setClassValue(kernels.get(clusterChoice).generator.getId());
+ // Do we need micro cluster representation if we have overlapping clusters?
+ // if(!overlappingOption.isSet())
+ kernels.get(clusterChoice).addInstance(inst);
+ }
+ // System.out.println(numGeneratedInstances+": Overlap is"+updateOverlaps());
+
+ return new InstanceExample(inst);
+ }
+
+ public Clustering getGeneratingClusters() {
+ Clustering clustering = new Clustering();
+ for (int c = 0; c < kernels.size(); c++) {
+ clustering.add(kernels.get(c).generator);
+ }
+ return clustering;
+ }
+
+ public Clustering getMicroClustering() {
+ Clustering clustering = new Clustering();
+ int id = 0;
+
+ for (int c = 0; c < kernels.size(); c++) {
+ for (int m = 0; m < kernels.get(c).microClusters.size(); m++) {
+ kernels.get(c).microClusters.get(m).setId(id);
+ kernels.get(c).microClusters.get(m).setGroundTruth(kernels.get(c).generator.getId());
+ clustering.add(kernels.get(c).microClusters.get(m));
+ id++;
+ }
+ }
+
+ // System.out.println("numMicroKernels "+clustering.size());
+ return clustering;
+ }
+
+ /**************************** EVENTS ******************************************/
+ private void eventScheduler() {
+
+ for (int i = 0; i < kernels.size(); i++) {
+ kernels.get(i).updateKernel();
+ }
+
+ nextEventCounter--;
+ // only move kernels every 10 points, performance reasons????
+ // should this be randomized as well???
+ if (nextEventCounter % kernelMovePointFrequency == 0) {
+ // move kernels
+ for (int i = 0; i < kernels.size(); i++) {
+ kernels.get(i).move();
+ // overlapControl();
+ }
+ }
+
+ if (eventFrequencyOption.getValue() == 0) {
+ return;
+ }
+
+ String type = "";
+ String message = "";
+ boolean eventFinished = false;
+ switch (nextEventChoice) {
+ case 0:
+ if (numActiveKernels > 1 && numActiveKernels > numClusterOption.getValue() - numClusterRangeOption.getValue()) {
+ message = mergeKernels(nextEventCounter);
+ type = "Merge";
+ }
+ if (mergeClusterA == null && mergeClusterB == null && message.startsWith("Clusters merging")) {
+ eventFinished = true;
+ }
+ break;
+ case 1:
+ if (nextEventCounter <= 0) {
+ if (numActiveKernels < numClusterOption.getValue() + numClusterRangeOption.getValue()) {
+ type = "Split";
+ message = splitKernel();
+ }
+ eventFinished = true;
+ }
+ break;
+ case 2:
+ if (nextEventCounter <= 0) {
+ if (numActiveKernels > 1 && numActiveKernels > numClusterOption.getValue() - numClusterRangeOption.getValue()) {
+ message = fadeOut();
+ type = "Delete";
+ }
+ eventFinished = true;
+ }
+ break;
+ case 3:
+ if (nextEventCounter <= 0) {
+ if (numActiveKernels < numClusterOption.getValue() + numClusterRangeOption.getValue()) {
+ message = fadeIn();
+ type = "Create";
+ }
+ eventFinished = true;
+ }
+ break;
+
+ }
+ if (eventFinished) {
+ nextEventCounter = (int) (eventFrequencyOption.getValue() + (instanceRandom.nextBoolean() ? -1 : 1)
+ * eventFrequencyOption.getValue() * eventFrequencyRange * instanceRandom.nextDouble());
+ nextEventChoice = getNextEvent();
+ // System.out.println("Next event choice: "+nextEventChoice);
+ }
+ if (!message.isEmpty()) {
+ message += " (numKernels = " + numActiveKernels + " at " + numGeneratedInstances + ")";
+ if (!type.equals("Merge") || message.startsWith("Clusters merging"))
+ fireClusterChange(numGeneratedInstances, type, message);
+ }
+ }
+
+ private int getNextEvent() {
+ int choice = -1;
+ boolean lowerLimit = numActiveKernels <= numClusterOption.getValue() - numClusterRangeOption.getValue();
+ boolean upperLimit = numActiveKernels >= numClusterOption.getValue() + numClusterRangeOption.getValue();
+
+ if (!lowerLimit || !upperLimit) {
+ int mode = -1;
+ if (eventDeleteCreateOption.isSet() && eventMergeSplitOption.isSet()) {
+ mode = instanceRandom.nextInt(2);
+ }
+
+ if (mode == 0 || (mode == -1 && eventMergeSplitOption.isSet())) {
+ // have we reached a limit? if not free choice
+ if (!lowerLimit && !upperLimit)
+ choice = instanceRandom.nextInt(2);
+ else
+ // we have a limit. if lower limit, choose split
+ if (lowerLimit)
+ choice = 1;
+ // otherwise we reached upper level, choose merge
+ else
+ choice = 0;
+ }
+
+ if (mode == 1 || (mode == -1 && eventDeleteCreateOption.isSet())) {
+ // have we reached a limit? if not free choice
+ if (!lowerLimit && !upperLimit)
+ choice = instanceRandom.nextInt(2) + 2;
+ else
+ // we have a limit. if lower limit, choose create
+ if (lowerLimit)
+ choice = 3;
+ // otherwise we reached upper level, choose delete
+ else
+ choice = 2;
+ }
+ }
+
+ return choice;
+ }
+
+ private String fadeOut() {
+ int id = instanceRandom.nextInt(kernels.size());
+ while (kernels.get(id).kill != -1)
+ id = instanceRandom.nextInt(kernels.size());
+
+ String message = kernels.get(id).fadeOut();
+ return message;
+ }
+
+ private String fadeIn() {
+ GeneratorCluster gc = new GeneratorCluster(clusterIdCounter++);
+ kernels.add(gc);
+ numActiveKernels++;
+ normalizeWeights();
+ return "Creating new cluster";
+ }
+
+ private String changeWeight(boolean increase) {
+ double changeRate = 0.1;
+ int id = instanceRandom.nextInt(kernels.size());
+ while (kernels.get(id).kill != -1)
+ id = instanceRandom.nextInt(kernels.size());
+
+ int sign = 1;
+ if (!increase)
+ sign = -1;
+ double weight_old = kernels.get(id).generator.getWeight();
+ double delta = sign * numActiveKernels * instanceRandom.nextDouble() * changeRate;
+ kernels.get(id).generator.setWeight(weight_old + delta);
+
+ normalizeWeights();
+
+ String message;
+ if (increase)
+ message = "Increase ";
+ else
+ message = "Decrease ";
+ message += " weight on Cluster " + id + " from " + weight_old + " to " + (weight_old + delta);
+ return message;
+
+ }
+
+ private String changeRadius(boolean increase) {
+ double maxChangeRate = 0.1;
+ int id = instanceRandom.nextInt(kernels.size());
+ while (kernels.get(id).kill != -1)
+ id = instanceRandom.nextInt(kernels.size());
+
+ int sign = 1;
+ if (!increase)
+ sign = -1;
+
+ double r_old = kernels.get(id).generator.getRadius();
+ double r_new = r_old + sign * r_old * instanceRandom.nextDouble() * maxChangeRate;
+ if (r_new >= 0.5)
+ return "Radius to big";
+ kernels.get(id).generator.setRadius(r_new);
+
+ String message;
+ if (increase)
+ message = "Increase ";
+ else
+ message = "Decrease ";
+ message += " radius on Cluster " + id + " from " + r_old + " to " + r_new;
+ return message;
+ }
+
+ private String splitKernel() {
+ int id = instanceRandom.nextInt(kernels.size());
+ while (kernels.get(id).kill != -1)
+ id = instanceRandom.nextInt(kernels.size());
+
+ String message = kernels.get(id).splitKernel();
+
+ return message;
+ }
+
+ private String mergeKernels(int steps) {
+ if (numActiveKernels > 1 && ((mergeClusterA == null && mergeClusterB == null))) {
+
+ // choose clusters to merge
+ double diseredDist = steps / speedOption.getValue() * maxDistanceMoveThresholdByStep;
+ double minDist = Double.MAX_VALUE;
+ // System.out.println("DisredDist:"+(2*diseredDist));
+ for (int i = 0; i < kernels.size(); i++) {
+ for (int j = 0; j < i; j++) {
+ if (kernels.get(i).kill != -1 || kernels.get(j).kill != -1) {
+ continue;
+ }
+ else {
+ double kernelDist = kernels.get(i).generator.getCenterDistance(kernels.get(j).generator);
+ double d = kernelDist - 2 * diseredDist;
+ // System.out.println("Dist:"+i+" / "+j+" "+d);
+ if (Math.abs(d) < minDist &&
+ (minDist != Double.MAX_VALUE || d > 0 || Math.abs(d) < 0.001)) {
+ minDist = Math.abs(d);
+ mergeClusterA = kernels.get(i);
+ mergeClusterB = kernels.get(j);
}
- incluster = false;
- if(!noiseInClusterOption.isSet() && counter > 0){
- counter--;
- for(int c = 0; c < kernels.size(); c++){
- for(int m = 0; m < kernels.get(c).microClusters.size(); m++){
- Instance inst = new DenseInstance(1, sample);
- if(kernels.get(c).microClusters.get(m).getInclusionProbability(inst) > 0){
- incluster = true;
- break;
- }
- }
- if(incluster)
- break;
- }
+ }
+ }
+ }
+
+ if (mergeClusterA != null && mergeClusterB != null) {
+ double[] merge_point = mergeClusterA.generator.getCenter();
+ double[] v = mergeClusterA.generator.getDistanceVector(mergeClusterB.generator);
+ for (int i = 0; i < v.length; i++) {
+ merge_point[i] = merge_point[i] + v[i] * 0.5;
+ }
+
+ mergeClusterA.merging = true;
+ mergeClusterB.merging = true;
+ mergeClusterA.setDesitnation(merge_point);
+ mergeClusterB.setDesitnation(merge_point);
+
+ if (debug) {
+ System.out.println("Center1" + Arrays.toString(mergeClusterA.generator.getCenter()));
+ System.out.println("Center2" + Arrays.toString(mergeClusterB.generator.getCenter()));
+ System.out.println("Vector" + Arrays.toString(v));
+
+ System.out.println("Try to merge cluster " + mergeClusterA.generator.getId() +
+ " into " + mergeClusterB.generator.getId() +
+ " at " + Arrays.toString(merge_point) +
+ " time " + numGeneratedInstances);
+ }
+ return "Init merge";
+ }
+ }
+
+ if (mergeClusterA != null && mergeClusterB != null) {
+
+ // movekernels will move the kernels close to each other,
+ // we just need to check and merge here if they are close enough
+ return mergeClusterA.tryMerging(mergeClusterB);
+ }
+
+ return "";
+ }
+
+ /************************* TOOLS **************************************/
+
+ public void getDescription(StringBuilder sb, int indent) {
+
+ }
+
+ private double[] getNoisePoint() {
+ double[] sample = new double[numAttsOption.getValue()];
+ boolean incluster = true;
+ int counter = 20;
+ while (incluster) {
+ for (int j = 0; j < numAttsOption.getValue(); j++) {
+ sample[j] = instanceRandom.nextDouble();
+ }
+ incluster = false;
+ if (!noiseInClusterOption.isSet() && counter > 0) {
+ counter--;
+ for (int c = 0; c < kernels.size(); c++) {
+ for (int m = 0; m < kernels.get(c).microClusters.size(); m++) {
+ Instance inst = new DenseInstance(1, sample);
+ if (kernels.get(c).microClusters.get(m).getInclusionProbability(inst) > 0) {
+ incluster = true;
+ break;
}
+ }
+ if (incluster)
+ break;
}
-
-// double [] sample = new double [numAttsOption.getValue()];
-// for (int j = 0; j < numAttsOption.getValue(); j++) {
-// sample[j] = instanceRandom.nextDouble();
-// }
-
- return sample;
+ }
}
- private int chooseWeightedElement() {
- double r = instanceRandom.nextDouble();
+ // double [] sample = new double [numAttsOption.getValue()];
+ // for (int j = 0; j < numAttsOption.getValue(); j++) {
+ // sample[j] = instanceRandom.nextDouble();
+ // }
- // Determine index of choosen element
- int i = 0;
- while (r > 0.0) {
- r -= kernels.get(i).generator.getWeight();
- i++;
- }
- --i; // Overcounted once
- //System.out.println(i);
- return i;
+ return sample;
+ }
+
+ private int chooseWeightedElement() {
+ double r = instanceRandom.nextDouble();
+
+ // Determine index of chosen element
+ int i = 0;
+ while (r > 0.0) {
+ r -= kernels.get(i).generator.getWeight();
+ i++;
}
+ --i; // Overcounted once
+ // System.out.println(i);
+ return i;
+ }
- private void normalizeWeights(){
- double sumWeights = 0.0;
- for (int i = 0; i < kernels.size(); i++) {
- sumWeights+=kernels.get(i).generator.getWeight();
- }
- for (int i = 0; i < kernels.size(); i++) {
- kernels.get(i).generator.setWeight(kernels.get(i).generator.getWeight()/sumWeights);
- }
+ private void normalizeWeights() {
+ double sumWeights = 0.0;
+ for (int i = 0; i < kernels.size(); i++) {
+ sumWeights += kernels.get(i).generator.getWeight();
}
+ for (int i = 0; i < kernels.size(); i++) {
+ kernels.get(i).generator.setWeight(kernels.get(i).generator.getWeight() / sumWeights);
+ }
+ }
+ /*************** EVENT Listener *********************/
+ // should go into the superclass of the generator, create new one for cluster
+ // streams?
-
- /*************** EVENT Listener *********************/
- // should go into the superclass of the generator, create new one for cluster streams?
-
/** Add a listener */
synchronized public void addClusterChangeListener(ClusterEventListener l) {
if (listeners == null)
@@ -939,17 +935,17 @@
if (listeners != null && !listeners.isEmpty()) {
// create the event object to send
ClusterEvent event =
- new ClusterEvent(this, timestamp, type , message);
+ new ClusterEvent(this, timestamp, type, message);
// make a copy of the listener list in case
- // anyone adds/removes listeners
+ // anyone adds/removes listeners
Vector targets;
synchronized (this) {
targets = (Vector) listeners.clone();
}
// walk through the listener list and
- // call the sunMoved method in each
+ // call the sunMoved method in each
Enumeration e = targets.elements();
while (e.hasMoreElements()) {
ClusterEventListener l = (ClusterEventListener) e.nextElement();
@@ -959,17 +955,13 @@
}
}
- @Override
- public String getPurposeString() {
- return "Generates a random radial basis function stream.";
- }
+ @Override
+ public String getPurposeString() {
+ return "Generates a random radial basis function stream.";
+ }
-
- public String getParameterString(){
- return "";
- }
-
-
-
+ public String getParameterString() {
+ return "";
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/HyperplaneGenerator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/HyperplaneGenerator.java
index d0ab735..873ab90 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/HyperplaneGenerator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/HyperplaneGenerator.java
@@ -45,134 +45,141 @@
*/
public class HyperplaneGenerator extends AbstractOptionHandler implements InstanceStream {
- @Override
- public String getPurposeString() {
- return "Generates a problem of predicting class of a rotating hyperplane.";
+ @Override
+ public String getPurposeString() {
+ return "Generates a problem of predicting class of a rotating hyperplane.";
+ }
+
+ private static final long serialVersionUID = 1L;
+
+ public IntOption instanceRandomSeedOption = new IntOption("instanceRandomSeed", 'i',
+ "Seed for random generation of instances.", 1);
+
+ public IntOption numClassesOption = new IntOption("numClasses", 'c', "The number of classes to generate.", 2, 2,
+ Integer.MAX_VALUE);
+
+ public IntOption numAttsOption = new IntOption("numAtts", 'a', "The number of attributes to generate.", 10, 0,
+ Integer.MAX_VALUE);
+
+ public IntOption numDriftAttsOption = new IntOption("numDriftAtts", 'k', "The number of attributes with drift.", 2,
+ 0, Integer.MAX_VALUE);
+
+ public FloatOption magChangeOption = new FloatOption("magChange", 't', "Magnitude of the change for every example",
+ 0.0, 0.0, 1.0);
+
+ public IntOption noisePercentageOption = new IntOption("noisePercentage", 'n',
+ "Percentage of noise to add to the data.", 5, 0, 100);
+
+ public IntOption sigmaPercentageOption = new IntOption("sigmaPercentage", 's',
+ "Percentage of probability that the direction of change is reversed.", 10,
+ 0, 100);
+
+ protected InstancesHeader streamHeader;
+
+ protected Random instanceRandom;
+
+ protected double[] weights;
+
+ protected int[] sigma;
+
+ public int numberInstance;
+
+ @Override
+ protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ monitor.setCurrentActivity("Preparing hyperplane...", -1.0);
+ generateHeader();
+ restart();
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ protected void generateHeader() {
+ FastVector attributes = new FastVector();
+ for (int i = 0; i < this.numAttsOption.getValue(); i++) {
+ attributes.addElement(new Attribute("att" + (i + 1)));
}
- private static final long serialVersionUID = 1L;
+ FastVector classLabels = new FastVector();
+ for (int i = 0; i < this.numClassesOption.getValue(); i++) {
+ classLabels.addElement("class" + (i + 1));
+ }
+ attributes.addElement(new Attribute("class", classLabels));
+ this.streamHeader = new InstancesHeader(new Instances(getCLICreationString(InstanceStream.class), attributes, 0));
+ this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
+ }
- public IntOption instanceRandomSeedOption = new IntOption("instanceRandomSeed", 'i', "Seed for random generation of instances.", 1);
+ @Override
+ public long estimatedRemainingInstances() {
+ return -1;
+ }
- public IntOption numClassesOption = new IntOption("numClasses", 'c', "The number of classes to generate.", 2, 2, Integer.MAX_VALUE);
+ @Override
+ public InstancesHeader getHeader() {
+ return this.streamHeader;
+ }
- public IntOption numAttsOption = new IntOption("numAtts", 'a', "The number of attributes to generate.", 10, 0, Integer.MAX_VALUE);
+ @Override
+ public boolean hasMoreInstances() {
+ return true;
+ }
- public IntOption numDriftAttsOption = new IntOption("numDriftAtts", 'k', "The number of attributes with drift.", 2, 0, Integer.MAX_VALUE);
+ @Override
+ public boolean isRestartable() {
+ return true;
+ }
- public FloatOption magChangeOption = new FloatOption("magChange", 't', "Magnitude of the change for every example", 0.0, 0.0, 1.0);
+ @Override
+ public Example<Instance> nextInstance() {
- public IntOption noisePercentageOption = new IntOption("noisePercentage", 'n', "Percentage of noise to add to the data.", 5, 0, 100);
-
- public IntOption sigmaPercentageOption = new IntOption("sigmaPercentage", 's', "Percentage of probability that the direction of change is reversed.", 10,
- 0, 100);
-
- protected InstancesHeader streamHeader;
-
- protected Random instanceRandom;
-
- protected double[] weights;
-
- protected int[] sigma;
-
- public int numberInstance;
-
- @Override
- protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- monitor.setCurrentActivity("Preparing hyperplane...", -1.0);
- generateHeader();
- restart();
+ int numAtts = this.numAttsOption.getValue();
+ double[] attVals = new double[numAtts + 1];
+ double sum = 0.0;
+ double sumWeights = 0.0;
+ for (int i = 0; i < numAtts; i++) {
+ attVals[i] = this.instanceRandom.nextDouble();
+ sum += this.weights[i] * attVals[i];
+ sumWeights += this.weights[i];
+ }
+ int classLabel;
+ if (sum >= sumWeights * 0.5) {
+ classLabel = 1;
+ } else {
+ classLabel = 0;
+ }
+ // Add Noise
+ if ((1 + (this.instanceRandom.nextInt(100))) <= this.noisePercentageOption.getValue()) {
+ classLabel = (classLabel == 0 ? 1 : 0);
}
- @SuppressWarnings({ "rawtypes", "unchecked" })
- protected void generateHeader() {
- FastVector attributes = new FastVector();
- for (int i = 0; i < this.numAttsOption.getValue(); i++) {
- attributes.addElement(new Attribute("att" + (i + 1)));
- }
+ Instance inst = new DenseInstance(1.0, attVals);
+ inst.setDataset(getHeader());
+ inst.setClassValue(classLabel);
+ addDrift();
+ return new InstanceExample(inst);
+ }
- FastVector classLabels = new FastVector();
- for (int i = 0; i < this.numClassesOption.getValue(); i++) {
- classLabels.addElement("class" + (i + 1));
- }
- attributes.addElement(new Attribute("class", classLabels));
- this.streamHeader = new InstancesHeader(new Instances(getCLICreationString(InstanceStream.class), attributes, 0));
- this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
+ private void addDrift() {
+ for (int i = 0; i < this.numDriftAttsOption.getValue(); i++) {
+ this.weights[i] += (double) ((double) sigma[i]) * ((double) this.magChangeOption.getValue());
+ if (// this.weights[i] >= 1.0 || this.weights[i] <= 0.0 ||
+ (1 + (this.instanceRandom.nextInt(100))) <= this.sigmaPercentageOption.getValue()) {
+ this.sigma[i] *= -1;
+ }
}
+ }
- @Override
- public long estimatedRemainingInstances() {
- return -1;
+ @Override
+ public void restart() {
+ this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
+ this.weights = new double[this.numAttsOption.getValue()];
+ this.sigma = new int[this.numAttsOption.getValue()];
+ for (int i = 0; i < this.numAttsOption.getValue(); i++) {
+ this.weights[i] = this.instanceRandom.nextDouble();
+ this.sigma[i] = (i < this.numDriftAttsOption.getValue() ? 1 : 0);
}
+ }
- @Override
- public InstancesHeader getHeader() {
- return this.streamHeader;
- }
-
- @Override
- public boolean hasMoreInstances() {
- return true;
- }
-
- @Override
- public boolean isRestartable() {
- return true;
- }
-
- @Override
- public Example<Instance> nextInstance() {
-
- int numAtts = this.numAttsOption.getValue();
- double[] attVals = new double[numAtts + 1];
- double sum = 0.0;
- double sumWeights = 0.0;
- for (int i = 0; i < numAtts; i++) {
- attVals[i] = this.instanceRandom.nextDouble();
- sum += this.weights[i] * attVals[i];
- sumWeights += this.weights[i];
- }
- int classLabel;
- if (sum >= sumWeights * 0.5) {
- classLabel = 1;
- } else {
- classLabel = 0;
- }
- // Add Noise
- if ((1 + (this.instanceRandom.nextInt(100))) <= this.noisePercentageOption.getValue()) {
- classLabel = (classLabel == 0 ? 1 : 0);
- }
-
- Instance inst = new DenseInstance(1.0, attVals);
- inst.setDataset(getHeader());
- inst.setClassValue(classLabel);
- addDrift();
- return new InstanceExample(inst);
- }
-
- private void addDrift() {
- for (int i = 0; i < this.numDriftAttsOption.getValue(); i++) {
- this.weights[i] += (double) ((double) sigma[i]) * ((double) this.magChangeOption.getValue());
- if (// this.weights[i] >= 1.0 || this.weights[i] <= 0.0 ||
- (1 + (this.instanceRandom.nextInt(100))) <= this.sigmaPercentageOption.getValue()) {
- this.sigma[i] *= -1;
- }
- }
- }
-
- @Override
- public void restart() {
- this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
- this.weights = new double[this.numAttsOption.getValue()];
- this.sigma = new int[this.numAttsOption.getValue()];
- for (int i = 0; i < this.numAttsOption.getValue(); i++) {
- this.weights[i] = this.instanceRandom.nextDouble();
- this.sigma[i] = (i < this.numDriftAttsOption.getValue() ? 1 : 0);
- }
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/RandomTreeGenerator.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/RandomTreeGenerator.java
index a7f982e..aa713cc 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/RandomTreeGenerator.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/streams/generators/RandomTreeGenerator.java
@@ -41,225 +41,227 @@
/**
* Stream generator for a stream based on a randomly generated tree..
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class RandomTreeGenerator extends AbstractOptionHandler implements InstanceStream {
- @Override
- public String getPurposeString() {
- return "Generates a stream based on a randomly generated tree.";
- }
+ @Override
+ public String getPurposeString() {
+ return "Generates a stream based on a randomly generated tree.";
+ }
+
+ private static final long serialVersionUID = 1L;
+
+ public IntOption treeRandomSeedOption = new IntOption("treeRandomSeed",
+ 'r', "Seed for random generation of tree.", 1);
+
+ public IntOption instanceRandomSeedOption = new IntOption(
+ "instanceRandomSeed", 'i',
+ "Seed for random generation of instances.", 1);
+
+ public IntOption numClassesOption = new IntOption("numClasses", 'c',
+ "The number of classes to generate.", 2, 2, Integer.MAX_VALUE);
+
+ public IntOption numNominalsOption = new IntOption("numNominals", 'o',
+ "The number of nominal attributes to generate.", 5, 0,
+ Integer.MAX_VALUE);
+
+ public IntOption numNumericsOption = new IntOption("numNumerics", 'u',
+ "The number of numeric attributes to generate.", 5, 0,
+ Integer.MAX_VALUE);
+
+ public IntOption numValsPerNominalOption = new IntOption(
+ "numValsPerNominal", 'v',
+ "The number of values to generate per nominal attribute.", 5, 2,
+ Integer.MAX_VALUE);
+
+ public IntOption maxTreeDepthOption = new IntOption("maxTreeDepth", 'd',
+ "The maximum depth of the tree concept.", 5, 0, Integer.MAX_VALUE);
+
+ public IntOption firstLeafLevelOption = new IntOption(
+ "firstLeafLevel",
+ 'l',
+ "The first level of the tree above maxTreeDepth that can have leaves.",
+ 3, 0, Integer.MAX_VALUE);
+
+ public FloatOption leafFractionOption = new FloatOption("leafFraction",
+ 'f',
+ "The fraction of leaves per level from firstLeafLevel onwards.",
+ 0.15, 0.0, 1.0);
+
+ protected static class Node implements Serializable {
private static final long serialVersionUID = 1L;
- public IntOption treeRandomSeedOption = new IntOption("treeRandomSeed",
- 'r', "Seed for random generation of tree.", 1);
+ public int classLabel;
- public IntOption instanceRandomSeedOption = new IntOption(
- "instanceRandomSeed", 'i',
- "Seed for random generation of instances.", 1);
+ public int splitAttIndex;
- public IntOption numClassesOption = new IntOption("numClasses", 'c',
- "The number of classes to generate.", 2, 2, Integer.MAX_VALUE);
+ public double splitAttValue;
- public IntOption numNominalsOption = new IntOption("numNominals", 'o',
- "The number of nominal attributes to generate.", 5, 0,
- Integer.MAX_VALUE);
+ public Node[] children;
+ }
- public IntOption numNumericsOption = new IntOption("numNumerics", 'u',
- "The number of numeric attributes to generate.", 5, 0,
- Integer.MAX_VALUE);
+ protected Node treeRoot;
- public IntOption numValsPerNominalOption = new IntOption(
- "numValsPerNominal", 'v',
- "The number of values to generate per nominal attribute.", 5, 2,
- Integer.MAX_VALUE);
+ protected InstancesHeader streamHeader;
- public IntOption maxTreeDepthOption = new IntOption("maxTreeDepth", 'd',
- "The maximum depth of the tree concept.", 5, 0, Integer.MAX_VALUE);
+ protected Random instanceRandom;
- public IntOption firstLeafLevelOption = new IntOption(
- "firstLeafLevel",
- 'l',
- "The first level of the tree above maxTreeDepth that can have leaves.",
- 3, 0, Integer.MAX_VALUE);
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor,
+ ObjectRepository repository) {
+ monitor.setCurrentActivity("Preparing random tree...", -1.0);
+ generateHeader();
+ generateRandomTree();
+ restart();
+ }
- public FloatOption leafFractionOption = new FloatOption("leafFraction",
- 'f',
- "The fraction of leaves per level from firstLeafLevel onwards.",
- 0.15, 0.0, 1.0);
+ @Override
+ public long estimatedRemainingInstances() {
+ return -1;
+ }
- protected static class Node implements Serializable {
+ @Override
+ public boolean isRestartable() {
+ return true;
+ }
- private static final long serialVersionUID = 1L;
+ @Override
+ public void restart() {
+ this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
+ }
- public int classLabel;
+ @Override
+ public InstancesHeader getHeader() {
+ return this.streamHeader;
+ }
- public int splitAttIndex;
+ @Override
+ public boolean hasMoreInstances() {
+ return true;
+ }
- public double splitAttValue;
-
- public Node[] children;
+ @Override
+ public InstanceExample nextInstance() {
+ double[] attVals = new double[this.numNominalsOption.getValue()
+ + this.numNumericsOption.getValue()];
+ InstancesHeader header = getHeader();
+ Instance inst = new DenseInstance(header.numAttributes());
+ for (int i = 0; i < attVals.length; i++) {
+ attVals[i] = i < this.numNominalsOption.getValue() ? this.instanceRandom.nextInt(this.numValsPerNominalOption
+ .getValue())
+ : this.instanceRandom.nextDouble();
+ inst.setValue(i, attVals[i]);
}
+ inst.setDataset(header);
+ inst.setClassValue(classifyInstance(this.treeRoot, attVals));
+ return new InstanceExample(inst);
+ }
- protected Node treeRoot;
-
- protected InstancesHeader streamHeader;
-
- protected Random instanceRandom;
-
- @Override
- public void prepareForUseImpl(TaskMonitor monitor,
- ObjectRepository repository) {
- monitor.setCurrentActivity("Preparing random tree...", -1.0);
- generateHeader();
- generateRandomTree();
- restart();
+ protected int classifyInstance(Node node, double[] attVals) {
+ if (node.children == null) {
+ return node.classLabel;
}
-
- @Override
- public long estimatedRemainingInstances() {
- return -1;
+ if (node.splitAttIndex < this.numNominalsOption.getValue()) {
+ return classifyInstance(
+ node.children[(int) attVals[node.splitAttIndex]], attVals);
}
+ return classifyInstance(
+ node.children[attVals[node.splitAttIndex] < node.splitAttValue ? 0
+ : 1], attVals);
+ }
- @Override
- public boolean isRestartable() {
- return true;
+ protected void generateHeader() {
+ FastVector<Attribute> attributes = new FastVector<>();
+ FastVector<String> nominalAttVals = new FastVector<>();
+ for (int i = 0; i < this.numValsPerNominalOption.getValue(); i++) {
+ nominalAttVals.addElement("value" + (i + 1));
}
+ for (int i = 0; i < this.numNominalsOption.getValue(); i++) {
+ attributes.addElement(new Attribute("nominal" + (i + 1),
+ nominalAttVals));
+ }
+ for (int i = 0; i < this.numNumericsOption.getValue(); i++) {
+ attributes.addElement(new Attribute("numeric" + (i + 1)));
+ }
+ FastVector<String> classLabels = new FastVector<>();
+ for (int i = 0; i < this.numClassesOption.getValue(); i++) {
+ classLabels.addElement("class" + (i + 1));
+ }
+ attributes.addElement(new Attribute("class", classLabels));
+ this.streamHeader = new InstancesHeader(new Instances(
+ getCLICreationString(InstanceStream.class), attributes, 0));
+ this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
+ }
- @Override
- public void restart() {
- this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue());
+ protected void generateRandomTree() {
+ Random treeRand = new Random(this.treeRandomSeedOption.getValue());
+ ArrayList<Integer> nominalAttCandidates = new ArrayList<>(
+ this.numNominalsOption.getValue());
+ for (int i = 0; i < this.numNominalsOption.getValue(); i++) {
+ nominalAttCandidates.add(i);
}
+ double[] minNumericVals = new double[this.numNumericsOption.getValue()];
+ double[] maxNumericVals = new double[this.numNumericsOption.getValue()];
+ for (int i = 0; i < this.numNumericsOption.getValue(); i++) {
+ minNumericVals[i] = 0.0;
+ maxNumericVals[i] = 1.0;
+ }
+ this.treeRoot = generateRandomTreeNode(0, nominalAttCandidates,
+ minNumericVals, maxNumericVals, treeRand);
+ }
- @Override
- public InstancesHeader getHeader() {
- return this.streamHeader;
+ protected Node generateRandomTreeNode(int currentDepth,
+ ArrayList<Integer> nominalAttCandidates, double[] minNumericVals,
+ double[] maxNumericVals, Random treeRand) {
+ if ((currentDepth >= this.maxTreeDepthOption.getValue())
+ || ((currentDepth >= this.firstLeafLevelOption.getValue()) && (this.leafFractionOption.getValue() >= (1.0 - treeRand
+ .nextDouble())))) {
+ Node leaf = new Node();
+ leaf.classLabel = treeRand.nextInt(this.numClassesOption.getValue());
+ return leaf;
}
+ Node node = new Node();
+ int chosenAtt = treeRand.nextInt(nominalAttCandidates.size()
+ + this.numNumericsOption.getValue());
+ if (chosenAtt < nominalAttCandidates.size()) {
+ node.splitAttIndex = nominalAttCandidates.get(chosenAtt);
+ node.children = new Node[this.numValsPerNominalOption.getValue()];
+ ArrayList<Integer> newNominalCandidates = new ArrayList<>(
+ nominalAttCandidates);
+ newNominalCandidates.remove(new Integer(node.splitAttIndex));
+ newNominalCandidates.trimToSize();
+ for (int i = 0; i < node.children.length; i++) {
+ node.children[i] = generateRandomTreeNode(currentDepth + 1,
+ newNominalCandidates, minNumericVals, maxNumericVals,
+ treeRand);
+ }
+ } else {
+ int numericIndex = chosenAtt - nominalAttCandidates.size();
+ node.splitAttIndex = this.numNominalsOption.getValue()
+ + numericIndex;
+ double minVal = minNumericVals[numericIndex];
+ double maxVal = maxNumericVals[numericIndex];
+ node.splitAttValue = ((maxVal - minVal) * treeRand.nextDouble())
+ + minVal;
+ node.children = new Node[2];
+ double[] newMaxVals = maxNumericVals.clone();
+ newMaxVals[numericIndex] = node.splitAttValue;
+ node.children[0] = generateRandomTreeNode(currentDepth + 1,
+ nominalAttCandidates, minNumericVals, newMaxVals, treeRand);
+ double[] newMinVals = minNumericVals.clone();
+ newMinVals[numericIndex] = node.splitAttValue;
+ node.children[1] = generateRandomTreeNode(currentDepth + 1,
+ nominalAttCandidates, newMinVals, maxNumericVals, treeRand);
+ }
+ return node;
+ }
- @Override
- public boolean hasMoreInstances() {
- return true;
- }
-
- @Override
- public InstanceExample nextInstance() {
- double[] attVals = new double[this.numNominalsOption.getValue()
- + this.numNumericsOption.getValue()];
- InstancesHeader header = getHeader();
- Instance inst = new DenseInstance(header.numAttributes());
- for (int i = 0; i < attVals.length; i++) {
- attVals[i] = i < this.numNominalsOption.getValue() ? this.instanceRandom.nextInt(this.numValsPerNominalOption.getValue())
- : this.instanceRandom.nextDouble();
- inst.setValue(i, attVals[i]);
- }
- inst.setDataset(header);
- inst.setClassValue(classifyInstance(this.treeRoot, attVals));
- return new InstanceExample(inst);
- }
-
- protected int classifyInstance(Node node, double[] attVals) {
- if (node.children == null) {
- return node.classLabel;
- }
- if (node.splitAttIndex < this.numNominalsOption.getValue()) {
- return classifyInstance(
- node.children[(int) attVals[node.splitAttIndex]], attVals);
- }
- return classifyInstance(
- node.children[attVals[node.splitAttIndex] < node.splitAttValue ? 0
- : 1], attVals);
- }
-
- protected void generateHeader() {
- FastVector<Attribute> attributes = new FastVector<>();
- FastVector<String> nominalAttVals = new FastVector<>();
- for (int i = 0; i < this.numValsPerNominalOption.getValue(); i++) {
- nominalAttVals.addElement("value" + (i + 1));
- }
- for (int i = 0; i < this.numNominalsOption.getValue(); i++) {
- attributes.addElement(new Attribute("nominal" + (i + 1),
- nominalAttVals));
- }
- for (int i = 0; i < this.numNumericsOption.getValue(); i++) {
- attributes.addElement(new Attribute("numeric" + (i + 1)));
- }
- FastVector<String> classLabels = new FastVector<>();
- for (int i = 0; i < this.numClassesOption.getValue(); i++) {
- classLabels.addElement("class" + (i + 1));
- }
- attributes.addElement(new Attribute("class", classLabels));
- this.streamHeader = new InstancesHeader(new Instances(
- getCLICreationString(InstanceStream.class), attributes, 0));
- this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1);
- }
-
- protected void generateRandomTree() {
- Random treeRand = new Random(this.treeRandomSeedOption.getValue());
- ArrayList<Integer> nominalAttCandidates = new ArrayList<>(
- this.numNominalsOption.getValue());
- for (int i = 0; i < this.numNominalsOption.getValue(); i++) {
- nominalAttCandidates.add(i);
- }
- double[] minNumericVals = new double[this.numNumericsOption.getValue()];
- double[] maxNumericVals = new double[this.numNumericsOption.getValue()];
- for (int i = 0; i < this.numNumericsOption.getValue(); i++) {
- minNumericVals[i] = 0.0;
- maxNumericVals[i] = 1.0;
- }
- this.treeRoot = generateRandomTreeNode(0, nominalAttCandidates,
- minNumericVals, maxNumericVals, treeRand);
- }
-
- protected Node generateRandomTreeNode(int currentDepth,
- ArrayList<Integer> nominalAttCandidates, double[] minNumericVals,
- double[] maxNumericVals, Random treeRand) {
- if ((currentDepth >= this.maxTreeDepthOption.getValue())
- || ((currentDepth >= this.firstLeafLevelOption.getValue()) && (this.leafFractionOption.getValue() >= (1.0 - treeRand.nextDouble())))) {
- Node leaf = new Node();
- leaf.classLabel = treeRand.nextInt(this.numClassesOption.getValue());
- return leaf;
- }
- Node node = new Node();
- int chosenAtt = treeRand.nextInt(nominalAttCandidates.size()
- + this.numNumericsOption.getValue());
- if (chosenAtt < nominalAttCandidates.size()) {
- node.splitAttIndex = nominalAttCandidates.get(chosenAtt);
- node.children = new Node[this.numValsPerNominalOption.getValue()];
- ArrayList<Integer> newNominalCandidates = new ArrayList<>(
- nominalAttCandidates);
- newNominalCandidates.remove(new Integer(node.splitAttIndex));
- newNominalCandidates.trimToSize();
- for (int i = 0; i < node.children.length; i++) {
- node.children[i] = generateRandomTreeNode(currentDepth + 1,
- newNominalCandidates, minNumericVals, maxNumericVals,
- treeRand);
- }
- } else {
- int numericIndex = chosenAtt - nominalAttCandidates.size();
- node.splitAttIndex = this.numNominalsOption.getValue()
- + numericIndex;
- double minVal = minNumericVals[numericIndex];
- double maxVal = maxNumericVals[numericIndex];
- node.splitAttValue = ((maxVal - minVal) * treeRand.nextDouble())
- + minVal;
- node.children = new Node[2];
- double[] newMaxVals = maxNumericVals.clone();
- newMaxVals[numericIndex] = node.splitAttValue;
- node.children[0] = generateRandomTreeNode(currentDepth + 1,
- nominalAttCandidates, minNumericVals, newMaxVals, treeRand);
- double[] newMinVals = minNumericVals.clone();
- newMinVals[numericIndex] = node.splitAttValue;
- node.children[1] = generateRandomTreeNode(currentDepth + 1,
- nominalAttCandidates, newMinVals, maxNumericVals, treeRand);
- }
- return node;
- }
-
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
- }
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/NullMonitor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/NullMonitor.java
index 977897f..da40835 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/NullMonitor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/NullMonitor.java
@@ -22,81 +22,81 @@
/**
* Class that represents a null monitor.
- *
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class NullMonitor implements TaskMonitor {
- @Override
- public void setCurrentActivity(String activityDescription,
- double fracComplete) {
- }
+ @Override
+ public void setCurrentActivity(String activityDescription,
+ double fracComplete) {
+ }
- @Override
- public void setCurrentActivityDescription(String activity) {
- }
+ @Override
+ public void setCurrentActivityDescription(String activity) {
+ }
- @Override
- public void setCurrentActivityFractionComplete(double fracComplete) {
- }
+ @Override
+ public void setCurrentActivityFractionComplete(double fracComplete) {
+ }
- @Override
- public boolean taskShouldAbort() {
- return false;
- }
+ @Override
+ public boolean taskShouldAbort() {
+ return false;
+ }
- @Override
- public String getCurrentActivityDescription() {
- return null;
- }
+ @Override
+ public String getCurrentActivityDescription() {
+ return null;
+ }
- @Override
- public double getCurrentActivityFractionComplete() {
- return -1.0;
- }
+ @Override
+ public double getCurrentActivityFractionComplete() {
+ return -1.0;
+ }
- @Override
- public boolean isPaused() {
- return false;
- }
+ @Override
+ public boolean isPaused() {
+ return false;
+ }
- @Override
- public boolean isCancelled() {
- return false;
- }
+ @Override
+ public boolean isCancelled() {
+ return false;
+ }
- @Override
- public void requestCancel() {
- }
+ @Override
+ public void requestCancel() {
+ }
- @Override
- public void requestPause() {
- }
+ @Override
+ public void requestPause() {
+ }
- @Override
- public void requestResume() {
- }
+ @Override
+ public void requestResume() {
+ }
- @Override
- public Object getLatestResultPreview() {
- return null;
- }
+ @Override
+ public Object getLatestResultPreview() {
+ return null;
+ }
- @Override
- public void requestResultPreview() {
- }
+ @Override
+ public void requestResultPreview() {
+ }
- @Override
- public boolean resultPreviewRequested() {
- return false;
- }
+ @Override
+ public boolean resultPreviewRequested() {
+ return false;
+ }
- @Override
- public void setLatestResultPreview(Object latestPreview) {
- }
+ @Override
+ public void setLatestResultPreview(Object latestPreview) {
+ }
- @Override
- public void requestResultPreview(ResultPreviewListener toInform) {
- }
+ @Override
+ public void requestResultPreview(ResultPreviewListener toInform) {
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/ResultPreviewListener.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/ResultPreviewListener.java
index 3fece0b..63d1236 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/ResultPreviewListener.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/ResultPreviewListener.java
@@ -21,20 +21,20 @@
*/
/**
- * Interface implemented by classes that preview results
- * on the Graphical User Interface
- *
+ * Interface implemented by classes that preview results on the Graphical User
+ * Interface
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public interface ResultPreviewListener {
- /**
- * This method is used to receive a signal from
- * <code>TaskMonitor</code> that the lastest preview has
- * changed. This method is implemented in <code>PreviewPanel</code>
- * to change the results that are shown in its panel.
- *
- */
- public void latestPreviewChanged();
+ /**
+ * This method is used to receive a signal from <code>TaskMonitor</code> that
+ * the latest preview has changed. This method is implemented in
+ * <code>PreviewPanel</code> to change the results that are shown in its
+ * panel.
+ *
+ */
+ public void latestPreviewChanged();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/Task.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/Task.java
index d2c96a8..d3dcc4f 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/Task.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/Task.java
@@ -24,38 +24,38 @@
import com.yahoo.labs.samoa.moa.core.ObjectRepository;
/**
- * Interface representing a task.
- *
+ * Interface representing a task.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface Task extends MOAObject {
- /**
- * Gets the result type of this task.
- * Tasks can return LearningCurve, LearningEvaluation,
- * Classifier, String, Instances..
- *
- * @return a class object of the result of this task
- */
- public Class<?> getTaskResultType();
+ /**
+ * Gets the result type of this task. Tasks can return LearningCurve,
+ * LearningEvaluation, Classifier, String, Instances..
+ *
+ * @return a class object of the result of this task
+ */
+ public Class<?> getTaskResultType();
- /**
- * This method performs this task,
- * when TaskMonitor and ObjectRepository are no needed.
- *
- * @return an object with the result of this task
- */
- public Object doTask();
+ /**
+ * This method performs this task, when TaskMonitor and ObjectRepository are
+ * not needed.
+ *
+ * @return an object with the result of this task
+ */
+ public Object doTask();
- /**
- * This method performs this task.
- * <code>AbstractTask</code> implements this method so all
- * its extensions only need to implement <code>doTaskImpl</code>
- *
- * @param monitor the TaskMonitor to use
- * @param repository the ObjectRepository to use
- * @return an object with the result of this task
- */
- public Object doTask(TaskMonitor monitor, ObjectRepository repository);
+ /**
+ * This method performs this task. <code>AbstractTask</code> implements this
+ * method so all its extensions only need to implement <code>doTaskImpl</code>
+ *
+ * @param monitor
+ * the TaskMonitor to use
+ * @param repository
+ * the ObjectRepository to use
+ * @return an object with the result of this task
+ */
+ public Object doTask(TaskMonitor monitor, ObjectRepository repository);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/TaskMonitor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/TaskMonitor.java
index 4918cd8..7b37f05 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/TaskMonitor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/tasks/TaskMonitor.java
@@ -21,120 +21,126 @@
*/
/**
- * Interface representing a task monitor.
- *
+ * Interface representing a task monitor.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @version $Revision: 7 $
+ * @version $Revision: 7 $
*/
public interface TaskMonitor {
- /**
- * Sets the description and the percentage done of the current activity.
- *
- * @param activity the description of the current activity
- * @param fracComplete the percentage done of the current activity
- */
- public void setCurrentActivity(String activityDescription,
- double fracComplete);
+ /**
+ * Sets the description and the percentage done of the current activity.
+ *
+ * @param activity
+ * the description of the current activity
+ * @param fracComplete
+ * the percentage done of the current activity
+ */
+ public void setCurrentActivity(String activityDescription,
+ double fracComplete);
- /**
- * Sets the description of the current activity.
- *
- * @param activity the description of the current activity
- */
- public void setCurrentActivityDescription(String activity);
+ /**
+ * Sets the description of the current activity.
+ *
+ * @param activity
+ * the description of the current activity
+ */
+ public void setCurrentActivityDescription(String activity);
- /**
- * Sets the percentage done of the current activity
- *
- * @param fracComplete the percentage done of the current activity
- */
- public void setCurrentActivityFractionComplete(double fracComplete);
+ /**
+ * Sets the percentage done of the current activity
+ *
+ * @param fracComplete
+ * the percentage done of the current activity
+ */
+ public void setCurrentActivityFractionComplete(double fracComplete);
- /**
- * Gets whether the task should abort.
- *
- * @return true if the task should abort
- */
- public boolean taskShouldAbort();
+ /**
+ * Gets whether the task should abort.
+ *
+ * @return true if the task should abort
+ */
+ public boolean taskShouldAbort();
- /**
- * Gets whether there is a request for preview the task result.
- *
- * @return true if there is a request for preview the task result
- */
- public boolean resultPreviewRequested();
+ /**
+ * Gets whether there is a request to preview the task result.
+ *
+ * @return true if there is a request to preview the task result
+ */
+ public boolean resultPreviewRequested();
- /**
- * Sets the current result to preview
- *
- * @param latestPreview the result to preview
- */
- public void setLatestResultPreview(Object latestPreview);
+ /**
+ * Sets the current result to preview
+ *
+ * @param latestPreview
+ * the result to preview
+ */
+ public void setLatestResultPreview(Object latestPreview);
- /**
- * Gets the description of the current activity.
- *
- * @return the description of the current activity
- */
- public String getCurrentActivityDescription();
+ /**
+ * Gets the description of the current activity.
+ *
+ * @return the description of the current activity
+ */
+ public String getCurrentActivityDescription();
- /**
- * Gets the percentage done of the current activity
- *
- * @return the percentage done of the current activity
- */
- public double getCurrentActivityFractionComplete();
+ /**
+ * Gets the percentage done of the current activity
+ *
+ * @return the percentage done of the current activity
+ */
+ public double getCurrentActivityFractionComplete();
- /**
- * Requests the task monitored to pause.
- *
- */
- public void requestPause();
+ /**
+ * Requests the task monitored to pause.
+ *
+ */
+ public void requestPause();
- /**
- * Requests the task monitored to resume.
- *
- */
- public void requestResume();
+ /**
+ * Requests the task monitored to resume.
+ *
+ */
+ public void requestResume();
- /**
- * Requests the task monitored to cancel.
- *
- */
- public void requestCancel();
+ /**
+ * Requests the task monitored to cancel.
+ *
+ */
+ public void requestCancel();
- /**
- * Gets whether the task monitored is paused.
- *
- * @return true if the task is paused
- */
- public boolean isPaused();
+ /**
+ * Gets whether the task monitored is paused.
+ *
+ * @return true if the task is paused
+ */
+ public boolean isPaused();
- /**
- * Gets whether the task monitored is cancelled.
- *
- * @return true if the task is cancelled
- */
- public boolean isCancelled();
+ /**
+ * Gets whether the task monitored is cancelled.
+ *
+ * @return true if the task is cancelled
+ */
+ public boolean isCancelled();
- /**
- * Requests to preview the task result.
- *
- */
- public void requestResultPreview();
+ /**
+ * Requests to preview the task result.
+ *
+ */
+ public void requestResultPreview();
- /**
- * Requests to preview the task result.
- *
- * @param toInform the listener of the changes in the preview of the result
- */
- public void requestResultPreview(ResultPreviewListener toInform);
+ /**
+ * Requests to preview the task result.
+ *
+ * @param toInform
+ * the listener of the changes in the preview of the result
+ */
+ public void requestResultPreview(ResultPreviewListener toInform);
- /**
- * Gets the current result to preview
- *
- * @return the result to preview
- */
- public Object getLatestResultPreview();
+ /**
+ * Gets the current result to preview
+ *
+ * @return the result to preview
+ */
+ public Object getLatestResultPreview();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ArffFileStream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ArffFileStream.java
index 93fc7c4..80a4910 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ArffFileStream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ArffFileStream.java
@@ -31,89 +31,89 @@
/**
* InstanceStream for ARFF file
+ *
* @author Casey
*/
public class ArffFileStream extends FileStream {
- public FileOption arffFileOption = new FileOption("arffFile", 'f',
- "ARFF File(s) to load.", null, null, false);
+ public FileOption arffFileOption = new FileOption("arffFile", 'f',
+ "ARFF File(s) to load.", null, null, false);
- public IntOption classIndexOption = new IntOption("classIndex", 'c',
- "Class index of data. 0 for none or -1 for last attribute in file.",
- -1, -1, Integer.MAX_VALUE);
-
- protected InstanceExample lastInstanceRead;
+ public IntOption classIndexOption = new IntOption("classIndex", 'c',
+ "Class index of data. 0 for none or -1 for last attribute in file.",
+ -1, -1, Integer.MAX_VALUE);
- @Override
- public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- super.prepareForUseImpl(monitor, repository);
- String filePath = this.arffFileOption.getFile().getAbsolutePath();
- this.fileSource.init(filePath, "arff");
- this.lastInstanceRead = null;
- }
-
- @Override
- protected void reset() {
- try {
- if (this.fileReader != null)
- this.fileReader.close();
+ protected InstanceExample lastInstanceRead;
- fileSource.reset();
- }
- catch (IOException ioe) {
- throw new RuntimeException("FileStream restart failed.", ioe);
- }
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ super.prepareForUseImpl(monitor, repository);
+ String filePath = this.arffFileOption.getFile().getAbsolutePath();
+ this.fileSource.init(filePath, "arff");
+ this.lastInstanceRead = null;
+ }
- if (!getNextFileReader()) {
- hitEndOfStream = true;
- throw new RuntimeException("FileStream is empty.");
- }
- }
-
- @Override
- protected boolean getNextFileReader() {
- boolean ret = super.getNextFileReader();
- if (ret) {
- this.instances = new Instances(this.fileReader, 1, -1);
- if (this.classIndexOption.getValue() < 0) {
- this.instances.setClassIndex(this.instances.numAttributes() - 1);
- } else if (this.classIndexOption.getValue() > 0) {
- this.instances.setClassIndex(this.classIndexOption.getValue() - 1);
- }
- }
- return ret;
- }
-
- @Override
- protected boolean readNextInstanceFromFile() {
- try {
- if (this.instances.readInstance(this.fileReader)) {
- this.lastInstanceRead = new InstanceExample(this.instances.instance(0));
- this.instances.delete(); // keep instances clean
- return true;
- }
- if (this.fileReader != null) {
- this.fileReader.close();
- this.fileReader = null;
- }
- return false;
- } catch (IOException ioe) {
- throw new RuntimeException(
- "ArffFileStream failed to read instance from stream.", ioe);
- }
+ @Override
+ protected void reset() {
+ try {
+ if (this.fileReader != null)
+ this.fileReader.close();
- }
-
- @Override
- protected InstanceExample getLastInstanceRead() {
- return this.lastInstanceRead;
- }
-
- /*
- * extend com.yahoo.labs.samoa.moa.MOAObject
- */
- @Override
- public void getDescription(StringBuilder sb, int indent) {
- // TODO Auto-generated method stub
+ fileSource.reset();
+ } catch (IOException ioe) {
+ throw new RuntimeException("FileStream restart failed.", ioe);
}
+
+ if (!getNextFileReader()) {
+ hitEndOfStream = true;
+ throw new RuntimeException("FileStream is empty.");
+ }
+ }
+
+ @Override
+ protected boolean getNextFileReader() {
+ boolean ret = super.getNextFileReader();
+ if (ret) {
+ this.instances = new Instances(this.fileReader, 1, -1);
+ if (this.classIndexOption.getValue() < 0) {
+ this.instances.setClassIndex(this.instances.numAttributes() - 1);
+ } else if (this.classIndexOption.getValue() > 0) {
+ this.instances.setClassIndex(this.classIndexOption.getValue() - 1);
+ }
+ }
+ return ret;
+ }
+
+ @Override
+ protected boolean readNextInstanceFromFile() {
+ try {
+ if (this.instances.readInstance(this.fileReader)) {
+ this.lastInstanceRead = new InstanceExample(this.instances.instance(0));
+ this.instances.delete(); // keep instances clean
+ return true;
+ }
+ if (this.fileReader != null) {
+ this.fileReader.close();
+ this.fileReader = null;
+ }
+ return false;
+ } catch (IOException ioe) {
+ throw new RuntimeException(
+ "ArffFileStream failed to read instance from stream.", ioe);
+ }
+
+ }
+
+ @Override
+ protected InstanceExample getLastInstanceRead() {
+ return this.lastInstanceRead;
+ }
+
+ /*
+ * extend com.yahoo.labs.samoa.moa.MOAObject
+ */
+ @Override
+ public void getDescription(StringBuilder sb, int indent) {
+ // TODO Auto-generated method stub
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ClusteringEntranceProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ClusteringEntranceProcessor.java
index 20f3feb..ddb047e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ClusteringEntranceProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/ClusteringEntranceProcessor.java
@@ -45,127 +45,167 @@
*/
public final class ClusteringEntranceProcessor implements EntranceProcessor {
- private static final long serialVersionUID = 4169053337917578558L;
+ private static final long serialVersionUID = 4169053337917578558L;
- private static final Logger logger = LoggerFactory.getLogger(ClusteringEntranceProcessor.class);
+ private static final Logger logger = LoggerFactory.getLogger(ClusteringEntranceProcessor.class);
- private StreamSource streamSource;
- private Instance firstInstance;
- private boolean isInited = false;
- private Random random = new Random();
- private double samplingThreshold;
- private int numberInstances;
- private int numInstanceSent = 0;
+ private StreamSource streamSource;
+ private Instance firstInstance;
+ private boolean isInited = false;
+ private Random random = new Random();
+ private double samplingThreshold;
+ private int numberInstances;
+ private int numInstanceSent = 0;
- private int groundTruthSamplingFrequency;
+ private int groundTruthSamplingFrequency;
- @Override
- public boolean process(ContentEvent event) {
- // TODO: possible refactor of the super-interface implementation
- // of source processor does not need this method
- return false;
+ @Override
+ public boolean process(ContentEvent event) {
+ // TODO: possible refactor of the super-interface implementation
+ // of source processor does not need this method
+ return false;
+ }
+
+ @Override
+ public void onCreate(int id) {
+ logger.debug("Creating ClusteringSourceProcessor with id {}", id);
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ ClusteringEntranceProcessor newProcessor = new ClusteringEntranceProcessor();
+ ClusteringEntranceProcessor originProcessor = (ClusteringEntranceProcessor) p;
+ if (originProcessor.getStreamSource() != null) {
+ newProcessor.setStreamSource(originProcessor.getStreamSource().getStream());
+ }
+ return newProcessor;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (!isFinished());
+ }
+
+ @Override
+ public boolean isFinished() {
+ return (!streamSource.hasMoreInstances() || (numberInstances >= 0 && numInstanceSent >= numberInstances));
+ }
+
+ // /**
+ // * Method to send instances via input stream
+ // *
+ // * @param inputStream
+ // * @param numberInstances
+ // */
+ // public void sendInstances(Stream inputStream, Stream evaluationStream, int
+ // numberInstances, double samplingThreshold) {
+ // int numInstanceSent = 0;
+ // this.samplingThreshold = samplingThreshold;
+ // while (streamSource.hasMoreInstances() && numInstanceSent <
+ // numberInstances) {
+ // numInstanceSent++;
+ // DataPoint nextDataPoint = new DataPoint(nextInstance(), numInstanceSent);
+ // ClusteringContentEvent contentEvent = new
+ // ClusteringContentEvent(numInstanceSent, nextDataPoint);
+ // inputStream.put(contentEvent);
+ // sendPointsAndGroundTruth(streamSource, evaluationStream, numInstanceSent,
+ // nextDataPoint);
+ // }
+ //
+ // sendEndEvaluationInstance(inputStream);
+ // }
+
+ public double getSamplingThreshold() {
+ return samplingThreshold;
+ }
+
+ public void setSamplingThreshold(double samplingThreshold) {
+ this.samplingThreshold = samplingThreshold;
+ }
+
+ public int getGroundTruthSamplingFrequency() {
+ return groundTruthSamplingFrequency;
+ }
+
+ public void setGroundTruthSamplingFrequency(int groundTruthSamplingFrequency) {
+ this.groundTruthSamplingFrequency = groundTruthSamplingFrequency;
+ }
+
+ public StreamSource getStreamSource() {
+ return streamSource;
+ }
+
+ public void setStreamSource(InstanceStream stream) {
+ if (stream instanceof AbstractOptionHandler) {
+ ((AbstractOptionHandler) (stream)).prepareForUse();
}
- @Override
- public void onCreate(int id) {
- logger.debug("Creating ClusteringSourceProcessor with id {}", id);
+ this.streamSource = new StreamSource(stream);
+ firstInstance = streamSource.nextInstance().getData();
+ }
+
+ public Instances getDataset() {
+ return firstInstance.dataset();
+ }
+
+ private Instance nextInstance() {
+ if (this.isInited) {
+ return streamSource.nextInstance().getData();
+ } else {
+ this.isInited = true;
+ return firstInstance;
}
+ }
- @Override
- public Processor newProcessor(Processor p) {
- ClusteringEntranceProcessor newProcessor = new ClusteringEntranceProcessor();
- ClusteringEntranceProcessor originProcessor = (ClusteringEntranceProcessor) p;
- if (originProcessor.getStreamSource() != null) {
- newProcessor.setStreamSource(originProcessor.getStreamSource().getStream());
- }
- return newProcessor;
- }
+ // private void sendEndEvaluationInstance(Stream inputStream) {
+ // ClusteringContentEvent contentEvent = new ClusteringContentEvent(-1,
+ // firstInstance);
+ // contentEvent.setLast(true);
+ // inputStream.put(contentEvent);
+ // }
- @Override
- public boolean hasNext() {
- return (!isFinished());
- }
+ // private void sendPointsAndGroundTruth(StreamSource sourceStream, Stream
+ // evaluationStream, int numInstanceSent, DataPoint nextDataPoint) {
+ // boolean sendEvent = false;
+ // DataPoint instance = null;
+ // Clustering gtClustering = null;
+ // int samplingFrequency = ((ClusteringStream)
+ // sourceStream.getStream()).getDecayHorizon();
+ // if (random.nextDouble() < samplingThreshold) {
+ // // Add instance
+ // sendEvent = true;
+ // instance = nextDataPoint;
+ // }
+ // if (numInstanceSent > 0 && numInstanceSent % samplingFrequency == 0) {
+ // // Add GroundTruth
+ // sendEvent = true;
+ // gtClustering = ((RandomRBFGeneratorEvents)
+ // sourceStream.getStream()).getGeneratingClusters();
+ // }
+ // if (sendEvent == true) {
+ // ClusteringEvaluationContentEvent evalEvent;
+ // evalEvent = new ClusteringEvaluationContentEvent(gtClustering, instance,
+ // false);
+ // evaluationStream.put(evalEvent);
+ // }
+ // }
- @Override
- public boolean isFinished() {
- return (!streamSource.hasMoreInstances() || (numberInstances >= 0 && numInstanceSent >= numberInstances));
- }
+ public void setMaxNumInstances(int value) {
+ numberInstances = value;
+ }
- // /**
- // * Method to send instances via input stream
- // *
- // * @param inputStream
- // * @param numberInstances
- // */
- // public void sendInstances(Stream inputStream, Stream evaluationStream, int numberInstances, double samplingThreshold) {
- // int numInstanceSent = 0;
- // this.samplingThreshold = samplingThreshold;
- // while (streamSource.hasMoreInstances() && numInstanceSent < numberInstances) {
- // numInstanceSent++;
- // DataPoint nextDataPoint = new DataPoint(nextInstance(), numInstanceSent);
- // ClusteringContentEvent contentEvent = new ClusteringContentEvent(numInstanceSent, nextDataPoint);
- // inputStream.put(contentEvent);
- // sendPointsAndGroundTruth(streamSource, evaluationStream, numInstanceSent, nextDataPoint);
- // }
- //
- // sendEndEvaluationInstance(inputStream);
- // }
+ public int getMaxNumInstances() {
+ return this.numberInstances;
+ }
- public double getSamplingThreshold() {
- return samplingThreshold;
- }
+ @Override
+ public ContentEvent nextEvent() {
- public void setSamplingThreshold(double samplingThreshold) {
- this.samplingThreshold = samplingThreshold;
- }
-
-
-
- public int getGroundTruthSamplingFrequency() {
- return groundTruthSamplingFrequency;
- }
-
- public void setGroundTruthSamplingFrequency(int groundTruthSamplingFrequency) {
- this.groundTruthSamplingFrequency = groundTruthSamplingFrequency;
- }
-
- public StreamSource getStreamSource() {
- return streamSource;
- }
-
- public void setStreamSource(InstanceStream stream) {
- if (stream instanceof AbstractOptionHandler) {
- ((AbstractOptionHandler) (stream)).prepareForUse();
- }
-
- this.streamSource = new StreamSource(stream);
- firstInstance = streamSource.nextInstance().getData();
- }
-
- public Instances getDataset() {
- return firstInstance.dataset();
- }
-
- private Instance nextInstance() {
- if (this.isInited) {
- return streamSource.nextInstance().getData();
- } else {
- this.isInited = true;
- return firstInstance;
- }
- }
-
- // private void sendEndEvaluationInstance(Stream inputStream) {
- // ClusteringContentEvent contentEvent = new ClusteringContentEvent(-1, firstInstance);
- // contentEvent.setLast(true);
- // inputStream.put(contentEvent);
- // }
-
- // private void sendPointsAndGroundTruth(StreamSource sourceStream, Stream evaluationStream, int numInstanceSent, DataPoint nextDataPoint) {
// boolean sendEvent = false;
// DataPoint instance = null;
// Clustering gtClustering = null;
- // int samplingFrequency = ((ClusteringStream) sourceStream.getStream()).getDecayHorizon();
+ // int samplingFrequency = ((ClusteringStream)
+ // sourceStream.getStream()).getDecayHorizon();
// if (random.nextDouble() < samplingThreshold) {
// // Add instance
// sendEvent = true;
@@ -174,68 +214,52 @@
// if (numInstanceSent > 0 && numInstanceSent % samplingFrequency == 0) {
// // Add GroundTruth
// sendEvent = true;
- // gtClustering = ((RandomRBFGeneratorEvents) sourceStream.getStream()).getGeneratingClusters();
+ // gtClustering = ((RandomRBFGeneratorEvents)
+ // sourceStream.getStream()).getGeneratingClusters();
// }
// if (sendEvent == true) {
// ClusteringEvaluationContentEvent evalEvent;
- // evalEvent = new ClusteringEvaluationContentEvent(gtClustering, instance, false);
+ // evalEvent = new ClusteringEvaluationContentEvent(gtClustering, instance,
+ // false);
// evaluationStream.put(evalEvent);
// }
- // }
- public void setMaxNumInstances(int value) {
- numberInstances = value;
- }
-
- public int getMaxNumInstances() {
- return this.numberInstances;
- }
-
- @Override
- public ContentEvent nextEvent() {
-
- // boolean sendEvent = false;
- // DataPoint instance = null;
- // Clustering gtClustering = null;
- // int samplingFrequency = ((ClusteringStream) sourceStream.getStream()).getDecayHorizon();
- // if (random.nextDouble() < samplingThreshold) {
- // // Add instance
- // sendEvent = true;
- // instance = nextDataPoint;
- // }
- // if (numInstanceSent > 0 && numInstanceSent % samplingFrequency == 0) {
- // // Add GroundTruth
- // sendEvent = true;
- // gtClustering = ((RandomRBFGeneratorEvents) sourceStream.getStream()).getGeneratingClusters();
- // }
- // if (sendEvent == true) {
- // ClusteringEvaluationContentEvent evalEvent;
- // evalEvent = new ClusteringEvaluationContentEvent(gtClustering, instance, false);
- // evaluationStream.put(evalEvent);
- // }
-
- groundTruthSamplingFrequency = ((ClusteringStream) streamSource.getStream()).getDecayHorizon(); // FIXME should it be taken from the ClusteringEvaluation -f option instead?
- if (isFinished()) {
- // send ending event
- ClusteringContentEvent contentEvent = new ClusteringContentEvent(-1, firstInstance);
- contentEvent.setLast(true);
- return contentEvent;
- } else {
- DataPoint nextDataPoint = new DataPoint(nextInstance(), numInstanceSent);
- numInstanceSent++;
- if (numInstanceSent % groundTruthSamplingFrequency == 0) {
- // TODO implement an interface ClusteringGroundTruth with a getGeneratingClusters() method, check if the source implements the interface
- // send a clustering evaluation event for external measures (distance from the gt clusters)
- Clustering gtClustering = ((RandomRBFGeneratorEvents) streamSource.getStream()).getGeneratingClusters();
- return new ClusteringEvaluationContentEvent(gtClustering, nextDataPoint, false);
- } else {
- ClusteringContentEvent contentEvent = new ClusteringContentEvent(numInstanceSent, nextDataPoint);
- if (random.nextDouble() < samplingThreshold) {
- // send a clustering content event for internal measures (cohesion, separation)
- contentEvent.setSample(true);
- }
- return contentEvent;
- }
+ groundTruthSamplingFrequency = ((ClusteringStream) streamSource.getStream()).getDecayHorizon(); // FIXME
+ // should
+ // it
+ // be
+ // taken
+ // from
+ // the
+ // ClusteringEvaluation
+ // -f
+ // option
+ // instead?
+ if (isFinished()) {
+ // send ending event
+ ClusteringContentEvent contentEvent = new ClusteringContentEvent(-1, firstInstance);
+ contentEvent.setLast(true);
+ return contentEvent;
+ } else {
+ DataPoint nextDataPoint = new DataPoint(nextInstance(), numInstanceSent);
+ numInstanceSent++;
+ if (numInstanceSent % groundTruthSamplingFrequency == 0) {
+ // TODO implement an interface ClusteringGroundTruth with a
+ // getGeneratingClusters() method, check if the source implements the
+ // interface
+ // send a clustering evaluation event for external measures (distance
+ // from the gt clusters)
+ Clustering gtClustering = ((RandomRBFGeneratorEvents) streamSource.getStream()).getGeneratingClusters();
+ return new ClusteringEvaluationContentEvent(gtClustering, nextDataPoint, false);
+ } else {
+ ClusteringContentEvent contentEvent = new ClusteringContentEvent(numInstanceSent, nextDataPoint);
+ if (random.nextDouble() < samplingThreshold) {
+ // send a clustering content event for internal measures (cohesion,
+ // separation)
+ contentEvent.setSample(true);
}
+ return contentEvent;
+ }
}
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/FileStream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/FileStream.java
index c8004f5..548c0da 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/FileStream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/FileStream.java
@@ -37,138 +37,139 @@
import com.yahoo.labs.samoa.streams.fs.FileStreamSource;
/**
- * InstanceStream for files
- * (Abstract class: subclass this class for different file formats)
+ * InstanceStream for files (Abstract class: subclass this class for different
+ * file formats)
+ *
* @author Casey
*/
public abstract class FileStream extends AbstractOptionHandler implements InstanceStream {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 3028905554604259130L;
+ private static final long serialVersionUID = 3028905554604259130L;
- public ClassOption sourceTypeOption = new ClassOption("sourceType",
- 's', "Source Type (HDFS, local FS)", FileStreamSource.class,
- "LocalFileStreamSource");
-
- protected transient FileStreamSource fileSource;
- protected transient Reader fileReader;
- protected Instances instances;
-
- protected boolean hitEndOfStream;
- private boolean hasStarted;
+ public ClassOption sourceTypeOption = new ClassOption("sourceType",
+ 's', "Source Type (HDFS, local FS)", FileStreamSource.class,
+ "LocalFileStreamSource");
- /*
- * Constructors
- */
- public FileStream() {
- this.hitEndOfStream = false;
- }
-
- /*
- * implement InstanceStream
- */
- @Override
- public InstancesHeader getHeader() {
- return new InstancesHeader(this.instances);
- }
+ protected transient FileStreamSource fileSource;
+ protected transient Reader fileReader;
+ protected Instances instances;
- @Override
- public long estimatedRemainingInstances() {
- return -1;
- }
+ protected boolean hitEndOfStream;
+ private boolean hasStarted;
- @Override
- public boolean hasMoreInstances() {
- return !this.hitEndOfStream;
- }
-
- @Override
- public InstanceExample nextInstance() {
- if (this.getLastInstanceRead() == null) {
- readNextInstanceFromStream();
- }
- InstanceExample prevInstance = this.getLastInstanceRead();
- readNextInstanceFromStream();
- return prevInstance;
- }
+ /*
+ * Constructors
+ */
+ public FileStream() {
+ this.hitEndOfStream = false;
+ }
- @Override
- public boolean isRestartable() {
- return true;
- }
+ /*
+ * implement InstanceStream
+ */
+ @Override
+ public InstancesHeader getHeader() {
+ return new InstancesHeader(this.instances);
+ }
- @Override
- public void restart() {
- reset();
- hasStarted = false;
- }
+ @Override
+ public long estimatedRemainingInstances() {
+ return -1;
+ }
- protected void reset() {
- try {
- if (this.fileReader != null)
- this.fileReader.close();
-
- fileSource.reset();
- }
- catch (IOException ioe) {
- throw new RuntimeException("FileStream restart failed.", ioe);
- }
-
- if (!getNextFileReader()) {
- hitEndOfStream = true;
- throw new RuntimeException("FileStream is empty.");
- }
-
- this.instances = new Instances(this.fileReader, 1, -1);
- this.instances.setClassIndex(this.instances.numAttributes() - 1);
- }
-
- protected boolean getNextFileReader() {
- if (this.fileReader != null)
- try {
- this.fileReader.close();
- } catch (IOException ioe) {
- ioe.printStackTrace();
- }
-
- InputStream inputStream = this.fileSource.getNextInputStream();
- if (inputStream == null)
- return false;
+ @Override
+ public boolean hasMoreInstances() {
+ return !this.hitEndOfStream;
+ }
- this.fileReader = new BufferedReader(new InputStreamReader(inputStream));
- return true;
- }
-
- protected boolean readNextInstanceFromStream() {
- if (!hasStarted) {
- this.reset();
- hasStarted = true;
- }
-
- while (true) {
- if (readNextInstanceFromFile()) return true;
+ @Override
+ public InstanceExample nextInstance() {
+ if (this.getLastInstanceRead() == null) {
+ readNextInstanceFromStream();
+ }
+ InstanceExample prevInstance = this.getLastInstanceRead();
+ readNextInstanceFromStream();
+ return prevInstance;
+ }
- if (!getNextFileReader()) {
- this.hitEndOfStream = true;
- return false;
- }
- }
- }
-
- /**
- * Read next instance from the current file and assign it to
- * lastInstanceRead.
- * @return true if it was able to read next instance and
- * false if it was at the end of the file
- */
- protected abstract boolean readNextInstanceFromFile();
-
- protected abstract InstanceExample getLastInstanceRead();
-
- @Override
- public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
- this.fileSource = sourceTypeOption.getValue();
- this.hasStarted = false;
- }
+ @Override
+ public boolean isRestartable() {
+ return true;
+ }
+
+ @Override
+ public void restart() {
+ reset();
+ hasStarted = false;
+ }
+
+ protected void reset() {
+ try {
+ if (this.fileReader != null)
+ this.fileReader.close();
+
+ fileSource.reset();
+ } catch (IOException ioe) {
+ throw new RuntimeException("FileStream restart failed.", ioe);
+ }
+
+ if (!getNextFileReader()) {
+ hitEndOfStream = true;
+ throw new RuntimeException("FileStream is empty.");
+ }
+
+ this.instances = new Instances(this.fileReader, 1, -1);
+ this.instances.setClassIndex(this.instances.numAttributes() - 1);
+ }
+
+ protected boolean getNextFileReader() {
+ if (this.fileReader != null)
+ try {
+ this.fileReader.close();
+ } catch (IOException ioe) {
+ ioe.printStackTrace();
+ }
+
+ InputStream inputStream = this.fileSource.getNextInputStream();
+ if (inputStream == null)
+ return false;
+
+ this.fileReader = new BufferedReader(new InputStreamReader(inputStream));
+ return true;
+ }
+
+ protected boolean readNextInstanceFromStream() {
+ if (!hasStarted) {
+ this.reset();
+ hasStarted = true;
+ }
+
+ while (true) {
+ if (readNextInstanceFromFile())
+ return true;
+
+ if (!getNextFileReader()) {
+ this.hitEndOfStream = true;
+ return false;
+ }
+ }
+ }
+
+ /**
+ * Read next instance from the current file and assign it to lastInstanceRead.
+ *
+ * @return true if it was able to read next instance and false if it was at
+ * the end of the file
+ */
+ protected abstract boolean readNextInstanceFromFile();
+
+ protected abstract InstanceExample getLastInstanceRead();
+
+ @Override
+ public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
+ this.fileSource = sourceTypeOption.getValue();
+ this.hasStarted = false;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/PrequentialSourceProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/PrequentialSourceProcessor.java
index e9a5aa1..f9884e9 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/PrequentialSourceProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/PrequentialSourceProcessor.java
@@ -38,192 +38,196 @@
import com.yahoo.labs.samoa.moa.streams.InstanceStream;
/**
- * Prequential Source Processor is the processor for Prequential Evaluation Task.
+ * Prequential Source Processor is the processor for Prequential Evaluation
+ * Task.
*
* @author Arinto Murdopo
*
*/
public final class PrequentialSourceProcessor implements EntranceProcessor {
- private static final long serialVersionUID = 4169053337917578558L;
+ private static final long serialVersionUID = 4169053337917578558L;
- private static final Logger logger = LoggerFactory.getLogger(PrequentialSourceProcessor.class);
- private boolean isInited = false;
- private StreamSource streamSource;
- private Instance firstInstance;
- private int numberInstances;
- private int numInstanceSent = 0;
+ private static final Logger logger = LoggerFactory.getLogger(PrequentialSourceProcessor.class);
+ private boolean isInited = false;
+ private StreamSource streamSource;
+ private Instance firstInstance;
+ private int numberInstances;
+ private int numInstanceSent = 0;
- protected InstanceStream sourceStream;
-
- /*
- * ScheduledExecutorService to schedule sending events after each delay interval.
- * It is expected to have only one event in the queue at a time, so we need only
- * one thread in the pool.
- */
- private transient ScheduledExecutorService timer;
- private transient ScheduledFuture<?> schedule = null;
- private int readyEventIndex = 1; // No waiting for the first event
- private int delay = 0;
- private int batchSize = 1;
- private boolean finished = false;
+ protected InstanceStream sourceStream;
- @Override
- public boolean process(ContentEvent event) {
- // TODO: possible refactor of the super-interface implementation
- // of source processor does not need this method
- return false;
+ /*
+ * ScheduledExecutorService to schedule sending events after each delay
+ * interval. It is expected to have only one event in the queue at a time, so
+ * we need only one thread in the pool.
+ */
+ private transient ScheduledExecutorService timer;
+ private transient ScheduledFuture<?> schedule = null;
+ private int readyEventIndex = 1; // No waiting for the first event
+ private int delay = 0;
+ private int batchSize = 1;
+ private boolean finished = false;
+
+ @Override
+ public boolean process(ContentEvent event) {
+ // TODO: possible refactor of the super-interface implementation
+ // of source processor does not need this method
+ return false;
+ }
+
+ @Override
+ public boolean isFinished() {
+ return finished;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return !isFinished() && (delay <= 0 || numInstanceSent < readyEventIndex);
+ }
+
+ private boolean hasReachedEndOfStream() {
+ return (!streamSource.hasMoreInstances() || (numberInstances >= 0 && numInstanceSent >= numberInstances));
+ }
+
+ @Override
+ public ContentEvent nextEvent() {
+ InstanceContentEvent contentEvent = null;
+ if (hasReachedEndOfStream()) {
+ contentEvent = new InstanceContentEvent(-1, firstInstance, false, true);
+ contentEvent.setLast(true);
+ // set finished status _after_ tagging last event
+ finished = true;
}
-
- @Override
- public boolean isFinished() {
- return finished;
+ else if (hasNext()) {
+ numInstanceSent++;
+ contentEvent = new InstanceContentEvent(numInstanceSent, nextInstance(), true, true);
+
+ // first call to this method will trigger the timer
+ if (schedule == null && delay > 0) {
+ schedule = timer.scheduleWithFixedDelay(new DelayTimeoutHandler(this), delay, delay,
+ TimeUnit.MICROSECONDS);
+ }
+ }
+ return contentEvent;
+ }
+
+ private void increaseReadyEventIndex() {
+ readyEventIndex += batchSize;
+ // if we exceed the max, cancel the timer
+ if (schedule != null && isFinished()) {
+ schedule.cancel(false);
+ }
+ }
+
+ @Override
+ public void onCreate(int id) {
+ initStreamSource(sourceStream);
+ timer = Executors.newScheduledThreadPool(1);
+ logger.debug("Creating PrequentialSourceProcessor with id {}", id);
+ }
+
+ @Override
+ public Processor newProcessor(Processor p) {
+ PrequentialSourceProcessor newProcessor = new PrequentialSourceProcessor();
+ PrequentialSourceProcessor originProcessor = (PrequentialSourceProcessor) p;
+ if (originProcessor.getStreamSource() != null) {
+ newProcessor.setStreamSource(originProcessor.getStreamSource().getStream());
+ }
+ return newProcessor;
+ }
+
+ // /**
+ // * Method to send instances via input stream
+ // *
+ // * @param inputStream
+ // * @param numberInstances
+ // */
+ // public void sendInstances(Stream inputStream, int numberInstances) {
+ // int numInstanceSent = 0;
+ // initStreamSource(sourceStream);
+ //
+ // while (streamSource.hasMoreInstances() && numInstanceSent <
+ // numberInstances) {
+ // numInstanceSent++;
+ // InstanceContentEvent contentEvent = new
+ // InstanceContentEvent(numInstanceSent, nextInstance(), true, true);
+ // inputStream.put(contentEvent);
+ // }
+ //
+ // sendEndEvaluationInstance(inputStream);
+ // }
+
+ public StreamSource getStreamSource() {
+ return streamSource;
+ }
+
+ public void setStreamSource(InstanceStream stream) {
+ this.sourceStream = stream;
+ }
+
+ public Instances getDataset() {
+ if (firstInstance == null) {
+ initStreamSource(sourceStream);
+ }
+ return firstInstance.dataset();
+ }
+
+ private Instance nextInstance() {
+ if (this.isInited) {
+ return streamSource.nextInstance().getData();
+ } else {
+ this.isInited = true;
+ return firstInstance;
+ }
+ }
+
+ // private void sendEndEvaluationInstance(Stream inputStream) {
+ // InstanceContentEvent contentEvent = new InstanceContentEvent(-1,
+ // firstInstance, false, true);
+ // contentEvent.setLast(true);
+ // inputStream.put(contentEvent);
+ // }
+
+ private void initStreamSource(InstanceStream stream) {
+ if (stream instanceof AbstractOptionHandler) {
+ ((AbstractOptionHandler) (stream)).prepareForUse();
}
- @Override
- public boolean hasNext() {
- return !isFinished() && (delay <= 0 || numInstanceSent < readyEventIndex);
+ this.streamSource = new StreamSource(stream);
+ firstInstance = streamSource.nextInstance().getData();
+ }
+
+ public void setMaxNumInstances(int value) {
+ numberInstances = value;
+ }
+
+ public int getMaxNumInstances() {
+ return this.numberInstances;
+ }
+
+ public void setSourceDelay(int delay) {
+ this.delay = delay;
+ }
+
+ public int getSourceDelay() {
+ return this.delay;
+ }
+
+ public void setDelayBatchSize(int batch) {
+ this.batchSize = batch;
+ }
+
+ private class DelayTimeoutHandler implements Runnable {
+
+ private PrequentialSourceProcessor processor;
+
+ public DelayTimeoutHandler(PrequentialSourceProcessor processor) {
+ this.processor = processor;
}
- private boolean hasReachedEndOfStream() {
- return (!streamSource.hasMoreInstances() || (numberInstances >= 0 && numInstanceSent >= numberInstances));
+ public void run() {
+ processor.increaseReadyEventIndex();
}
-
- @Override
- public ContentEvent nextEvent() {
- InstanceContentEvent contentEvent = null;
- if (hasReachedEndOfStream()) {
- contentEvent = new InstanceContentEvent(-1, firstInstance, false, true);
- contentEvent.setLast(true);
- // set finished status _after_ tagging last event
- finished = true;
- }
- else if (hasNext()) {
- numInstanceSent++;
- contentEvent = new InstanceContentEvent(numInstanceSent, nextInstance(), true, true);
-
- // first call to this method will trigger the timer
- if (schedule == null && delay > 0) {
- schedule = timer.scheduleWithFixedDelay(new DelayTimeoutHandler(this), delay, delay,
- TimeUnit.MICROSECONDS);
- }
- }
- return contentEvent;
- }
-
- private void increaseReadyEventIndex() {
- readyEventIndex+=batchSize;
- // if we exceed the max, cancel the timer
- if (schedule != null && isFinished()) {
- schedule.cancel(false);
- }
- }
-
- @Override
- public void onCreate(int id) {
- initStreamSource(sourceStream);
- timer = Executors.newScheduledThreadPool(1);
- logger.debug("Creating PrequentialSourceProcessor with id {}", id);
- }
-
- @Override
- public Processor newProcessor(Processor p) {
- PrequentialSourceProcessor newProcessor = new PrequentialSourceProcessor();
- PrequentialSourceProcessor originProcessor = (PrequentialSourceProcessor) p;
- if (originProcessor.getStreamSource() != null) {
- newProcessor.setStreamSource(originProcessor.getStreamSource().getStream());
- }
- return newProcessor;
- }
-
-// /**
-// * Method to send instances via input stream
-// *
-// * @param inputStream
-// * @param numberInstances
-// */
-// public void sendInstances(Stream inputStream, int numberInstances) {
-// int numInstanceSent = 0;
-// initStreamSource(sourceStream);
-//
-// while (streamSource.hasMoreInstances() && numInstanceSent < numberInstances) {
-// numInstanceSent++;
-// InstanceContentEvent contentEvent = new InstanceContentEvent(numInstanceSent, nextInstance(), true, true);
-// inputStream.put(contentEvent);
-// }
-//
-// sendEndEvaluationInstance(inputStream);
-// }
-
- public StreamSource getStreamSource() {
- return streamSource;
- }
-
- public void setStreamSource(InstanceStream stream) {
- this.sourceStream = stream;
- }
-
- public Instances getDataset() {
- if (firstInstance == null) {
- initStreamSource(sourceStream);
- }
- return firstInstance.dataset();
- }
-
- private Instance nextInstance() {
- if (this.isInited) {
- return streamSource.nextInstance().getData();
- } else {
- this.isInited = true;
- return firstInstance;
- }
- }
-
-// private void sendEndEvaluationInstance(Stream inputStream) {
-// InstanceContentEvent contentEvent = new InstanceContentEvent(-1, firstInstance, false, true);
-// contentEvent.setLast(true);
-// inputStream.put(contentEvent);
-// }
-
- private void initStreamSource(InstanceStream stream) {
- if (stream instanceof AbstractOptionHandler) {
- ((AbstractOptionHandler) (stream)).prepareForUse();
- }
-
- this.streamSource = new StreamSource(stream);
- firstInstance = streamSource.nextInstance().getData();
- }
-
- public void setMaxNumInstances(int value) {
- numberInstances = value;
- }
-
- public int getMaxNumInstances() {
- return this.numberInstances;
- }
-
- public void setSourceDelay(int delay) {
- this.delay = delay;
- }
-
- public int getSourceDelay() {
- return this.delay;
- }
-
- public void setDelayBatchSize(int batch) {
- this.batchSize = batch;
- }
-
- private class DelayTimeoutHandler implements Runnable {
-
- private PrequentialSourceProcessor processor;
-
- public DelayTimeoutHandler(PrequentialSourceProcessor processor) {
- this.processor = processor;
- }
-
- public void run() {
- processor.increaseReadyEventIndex();
- }
- }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSource.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSource.java
index 453d02d..4eca28c 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSource.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSource.java
@@ -31,60 +31,62 @@
/**
* The Class StreamSource.
*/
-public class StreamSource implements java.io.Serializable{
+public class StreamSource implements java.io.Serializable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 3974668694861231236L;
+ private static final long serialVersionUID = 3974668694861231236L;
- /**
- * Instantiates a new stream source.
- *
- * @param stream the stream
- */
- public StreamSource(InstanceStream stream) {
- super();
- this.stream = stream;
- }
+ /**
+ * Instantiates a new stream source.
+ *
+ * @param stream
+ * the stream
+ */
+ public StreamSource(InstanceStream stream) {
+ super();
+ this.stream = stream;
+ }
- /** The stream. */
- protected InstanceStream stream;
+ /** The stream. */
+ protected InstanceStream stream;
- /**
- * Gets the stream.
- *
- * @return the stream
- */
- public InstanceStream getStream() {
- return stream;
- }
+ /**
+ * Gets the stream.
+ *
+ * @return the stream
+ */
+ public InstanceStream getStream() {
+ return stream;
+ }
- /**
- * Next instance.
- *
- * @return the instance
- */
- public Example<Instance> nextInstance() {
- return stream.nextInstance();
- }
+ /**
+ * Next instance.
+ *
+ * @return the instance
+ */
+ public Example<Instance> nextInstance() {
+ return stream.nextInstance();
+ }
- /**
- * Sets the stream.
- *
- * @param stream the new stream
- */
- public void setStream(InstanceStream stream) {
- this.stream = stream;
- }
+ /**
+ * Sets the stream.
+ *
+ * @param stream
+ * the new stream
+ */
+ public void setStream(InstanceStream stream) {
+ this.stream = stream;
+ }
- /**
- * Checks for more instances.
- *
- * @return true, if successful
- */
- public boolean hasMoreInstances() {
- return this.stream.hasMoreInstances();
- }
+ /**
+ * Checks for more instances.
+ *
+ * @return true, if successful
+ */
+ public boolean hasMoreInstances() {
+ return this.stream.hasMoreInstances();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSourceProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSourceProcessor.java
index 2e66e4b..980802f 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSourceProcessor.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/StreamSourceProcessor.java
@@ -39,147 +39,157 @@
* The Class StreamSourceProcessor.
*/
public class StreamSourceProcessor implements Processor {
-
- /** The Constant logger. */
- private static final Logger logger = LoggerFactory
- .getLogger(StreamSourceProcessor.class);
- /**
+ /** The Constant logger. */
+ private static final Logger logger = LoggerFactory
+ .getLogger(StreamSourceProcessor.class);
+
+ /**
*
*/
- private static final long serialVersionUID = -204182279475432739L;
+ private static final long serialVersionUID = -204182279475432739L;
- /** The stream source. */
- private StreamSource streamSource;
+ /** The stream source. */
+ private StreamSource streamSource;
- /**
- * Gets the stream source.
- *
- * @return the stream source
- */
- public StreamSource getStreamSource() {
- return streamSource;
- }
+ /**
+ * Gets the stream source.
+ *
+ * @return the stream source
+ */
+ public StreamSource getStreamSource() {
+ return streamSource;
+ }
- /**
- * Sets the stream source.
- *
- * @param stream the new stream source
- */
- public void setStreamSource(InstanceStream stream) {
- this.streamSource = new StreamSource(stream);
- firstInstance = streamSource.nextInstance().getData();
- }
+ /**
+ * Sets the stream source.
+ *
+ * @param stream
+ * the new stream source
+ */
+ public void setStreamSource(InstanceStream stream) {
+ this.streamSource = new StreamSource(stream);
+ firstInstance = streamSource.nextInstance().getData();
+ }
- /** The number instances sent. */
- private long numberInstancesSent = 0;
+ /** The number instances sent. */
+ private long numberInstancesSent = 0;
- /**
- * Send instances.
- * @param inputStream the input stream
- * @param numberInstances the number instances
- * @param isTraining the is training
- */
- public void sendInstances(Stream inputStream,
- int numberInstances, boolean isTraining, boolean isTesting) {
- int numberSamples = 0;
+ /**
+ * Send instances.
+ *
+ * @param inputStream
+ * the input stream
+ * @param numberInstances
+ * the number instances
+ * @param isTraining
+ * the is training
+ */
+ public void sendInstances(Stream inputStream,
+ int numberInstances, boolean isTraining, boolean isTesting) {
+ int numberSamples = 0;
- while (streamSource.hasMoreInstances()
- && numberSamples < numberInstances) {
-
- numberSamples++;
- numberInstancesSent++;
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
- numberInstancesSent, nextInstance(), isTraining, isTesting);
-
-
- inputStream.put(instanceContentEvent);
- }
+ while (streamSource.hasMoreInstances()
+ && numberSamples < numberInstances) {
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
- numberInstancesSent, null, isTraining, isTesting);
- instanceContentEvent.setLast(true);
- inputStream.put(instanceContentEvent);
- }
+ numberSamples++;
+ numberInstancesSent++;
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
+ numberInstancesSent, nextInstance(), isTraining, isTesting);
- /**
- * Send end evaluation instance.
- *
- * @param inputStream the input stream
- */
- public void sendEndEvaluationInstance(Stream inputStream) {
- InstanceContentEvent instanceContentEvent = new InstanceContentEvent(-1, firstInstance,false, true);
- inputStream.put(instanceContentEvent);
- }
+ inputStream.put(instanceContentEvent);
+ }
- /**
- * Next instance.
- *
- * @return the instance
- */
- protected Instance nextInstance() {
- if (this.isInited) {
- return streamSource.nextInstance().getData();
- } else {
- this.isInited = true;
- return firstInstance;
- }
- }
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
+ numberInstancesSent, null, isTraining, isTesting);
+ instanceContentEvent.setLast(true);
+ inputStream.put(instanceContentEvent);
+ }
- /** The is inited. */
- protected boolean isInited = false;
-
- /** The first instance. */
- protected Instance firstInstance;
+ /**
+ * Send end evaluation instance.
+ *
+ * @param inputStream
+ * the input stream
+ */
+ public void sendEndEvaluationInstance(Stream inputStream) {
+ InstanceContentEvent instanceContentEvent = new InstanceContentEvent(-1, firstInstance, false, true);
+ inputStream.put(instanceContentEvent);
+ }
- //@Override
- /**
- * On remove.
- */
- protected void onRemove() {
- }
+ /**
+ * Next instance.
+ *
+ * @return the instance
+ */
+ protected Instance nextInstance() {
+ if (this.isInited) {
+ return streamSource.nextInstance().getData();
+ } else {
+ this.isInited = true;
+ return firstInstance;
+ }
+ }
- /* (non-Javadoc)
- * @see samoa.core.Processor#onCreate(int)
- */
- @Override
- public void onCreate(int id) {
- // TODO Auto-generated method stub
- }
+ /** The is inited. */
+ protected boolean isInited = false;
- /* (non-Javadoc)
- * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
- */
- @Override
- public Processor newProcessor(Processor sourceProcessor) {
-// StreamSourceProcessor newProcessor = new StreamSourceProcessor();
-// StreamSourceProcessor originProcessor = (StreamSourceProcessor) sourceProcessor;
-// if (originProcessor.getStreamSource() != null){
-// newProcessor.setStreamSource(originProcessor.getStreamSource().getStream());
-// }
- //return newProcessor;
- return null;
- }
+ /** The first instance. */
+ protected Instance firstInstance;
- /**
- * On event.
- *
- * @param event the event
- * @return true, if successful
- */
- @Override
- public boolean process(ContentEvent event) {
- return false;
- }
-
-
- /**
- * Gets the dataset.
- *
- * @return the dataset
- */
- public Instances getDataset() {
- return firstInstance.dataset();
- }
+ // @Override
+ /**
+ * On remove.
+ */
+ protected void onRemove() {
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#onCreate(int)
+ */
+ @Override
+ public void onCreate(int id) {
+ // TODO Auto-generated method stub
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
+ */
+ @Override
+ public Processor newProcessor(Processor sourceProcessor) {
+ // StreamSourceProcessor newProcessor = new StreamSourceProcessor();
+ // StreamSourceProcessor originProcessor = (StreamSourceProcessor)
+ // sourceProcessor;
+ // if (originProcessor.getStreamSource() != null){
+ // newProcessor.setStreamSource(originProcessor.getStreamSource().getStream());
+ // }
+ // return newProcessor;
+ return null;
+ }
+
+ /**
+ * On event.
+ *
+ * @param event
+ * the event
+ * @return true, if successful
+ */
+ @Override
+ public boolean process(ContentEvent event) {
+ return false;
+ }
+
+ /**
+ * Gets the dataset.
+ *
+ * @return the dataset
+ */
+ public Instances getDataset() {
+ return firstInstance.dataset();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/FileStreamSource.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/FileStreamSource.java
index 25541e2..6d741f9 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/FileStreamSource.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/FileStreamSource.java
@@ -26,42 +26,41 @@
/**
* An interface for FileStream's source (Local FS, HDFS,...)
+ *
* @author Casey
*/
public interface FileStreamSource extends Serializable {
- /**
- * Init the source with file/directory path and file extension
- * @param path
- * File or directory path
- * @param ext
- * File extension to be used to filter files in a directory.
- * If null, all files in the directory are accepted.
- */
- public void init(String path, String ext);
-
- /**
- * Reset the source
- */
- public void reset() throws IOException;
-
- /**
- * Retrieve InputStream for next file.
- * This method will return null if we are at the last file
- * in the list.
- *
- * @return InputStream for next file in the list
- */
- public InputStream getNextInputStream();
-
- /**
- * Retrieve InputStream for current file.
- * The "current pointer" is moved forward
- * with getNextInputStream method. So if there was no
- * invocation of getNextInputStream, this method will
- * return null.
- *
- * @return InputStream for current file in the list
- */
- public InputStream getCurrentInputStream();
+ /**
+ * Init the source with file/directory path and file extension
+ *
+ * @param path
+ * File or directory path
+ * @param ext
+ * File extension to be used to filter files in a directory. If null,
+ * all files in the directory are accepted.
+ */
+ public void init(String path, String ext);
+
+ /**
+ * Reset the source
+ */
+ public void reset() throws IOException;
+
+ /**
+ * Retrieve InputStream for next file. This method will return null if we are
+ * at the last file in the list.
+ *
+ * @return InputStream for next file in the list
+ */
+ public InputStream getNextInputStream();
+
+ /**
+ * Retrieve InputStream for current file. The "current pointer" is moved
+ * forward with getNextInputStream method. So if there was no invocation of
+ * getNextInputStream, this method will return null.
+ *
+ * @return InputStream for current file in the list
+ */
+ public InputStream getCurrentInputStream();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSource.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSource.java
index 079423c..a4f885e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSource.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSource.java
@@ -34,117 +34,117 @@
/**
* Source for FileStream for HDFS files
+ *
* @author Casey
*/
public class HDFSFileStreamSource implements FileStreamSource {
-
- /**
+
+ /**
*
*/
- private static final long serialVersionUID = -3887354805787167400L;
-
- private transient InputStream fileStream;
- private transient Configuration config;
- private List<String> filePaths;
- private int currentIndex;
-
- public HDFSFileStreamSource(){
- this.currentIndex = -1;
- }
-
- public void init(String path, String ext) {
- this.init(this.getDefaultConfig(), path, ext);
- }
-
- public void init(Configuration config, String path, String ext) {
- this.config = config;
- this.filePaths = new ArrayList<String>();
- Path hdfsPath = new Path(path);
- FileSystem fs;
- try {
- fs = FileSystem.get(config);
- FileStatus fileStat = fs.getFileStatus(hdfsPath);
- if (fileStat.isDirectory()) {
- Path filterPath = hdfsPath;
- if (ext != null) {
- filterPath = new Path(path.toString(),"*."+ext);
- }
- else {
- filterPath = new Path(path.toString(),"*");
- }
- FileStatus[] filesInDir = fs.globStatus(filterPath);
- for (int i=0; i<filesInDir.length; i++) {
- if (filesInDir[i].isFile()) {
- filePaths.add(filesInDir[i].getPath().toString());
- }
- }
- }
- else {
- this.filePaths.add(path);
- }
- }
- catch(IOException ioe) {
- throw new RuntimeException("Failed getting list of files at:"+path,ioe);
- }
-
- this.currentIndex = -1;
- }
-
- private Configuration getDefaultConfig() {
- String hadoopHome = System.getenv("HADOOP_HOME");
- Configuration conf = new Configuration();
- if (hadoopHome != null) {
- java.nio.file.Path coreSitePath = FileSystems.getDefault().getPath(hadoopHome, "etc/hadoop/core-site.xml");
- java.nio.file.Path hdfsSitePath = FileSystems.getDefault().getPath(hadoopHome, "etc/hadoop/hdfs-site.xml");
- conf.addResource(new Path(coreSitePath.toAbsolutePath().toString()));
- conf.addResource(new Path(hdfsSitePath.toAbsolutePath().toString()));
- }
- return conf;
- }
-
- public void reset() throws IOException {
- this.currentIndex = -1;
- this.closeFileStream();
- }
+ private static final long serialVersionUID = -3887354805787167400L;
- private void closeFileStream() {
- IOUtils.closeStream(fileStream);
- }
+ private transient InputStream fileStream;
+ private transient Configuration config;
+ private List<String> filePaths;
+ private int currentIndex;
- public InputStream getNextInputStream() {
- this.closeFileStream();
- if (this.currentIndex >= (this.filePaths.size()-1)) return null;
-
- this.currentIndex++;
- String filePath = this.filePaths.get(currentIndex);
-
- Path hdfsPath = new Path(filePath);
- FileSystem fs;
- try {
- fs = FileSystem.get(config);
- fileStream = fs.open(hdfsPath);
- }
- catch(IOException ioe) {
- this.closeFileStream();
- throw new RuntimeException("Failed opening file:"+filePath,ioe);
- }
-
- return fileStream;
- }
+ public HDFSFileStreamSource() {
+ this.currentIndex = -1;
+ }
- public InputStream getCurrentInputStream() {
- return fileStream;
- }
-
- protected int getFilePathListSize() {
- if (filePaths != null)
- return filePaths.size();
- return 0;
- }
-
- protected String getFilePathAt(int index) {
- if (filePaths != null && filePaths.size() > index)
- return filePaths.get(index);
- return null;
- }
+ public void init(String path, String ext) {
+ this.init(this.getDefaultConfig(), path, ext);
+ }
+
+ public void init(Configuration config, String path, String ext) {
+ this.config = config;
+ this.filePaths = new ArrayList<String>();
+ Path hdfsPath = new Path(path);
+ FileSystem fs;
+ try {
+ fs = FileSystem.get(config);
+ FileStatus fileStat = fs.getFileStatus(hdfsPath);
+ if (fileStat.isDirectory()) {
+ Path filterPath = hdfsPath;
+ if (ext != null) {
+ filterPath = new Path(path.toString(), "*." + ext);
+ }
+ else {
+ filterPath = new Path(path.toString(), "*");
+ }
+ FileStatus[] filesInDir = fs.globStatus(filterPath);
+ for (int i = 0; i < filesInDir.length; i++) {
+ if (filesInDir[i].isFile()) {
+ filePaths.add(filesInDir[i].getPath().toString());
+ }
+ }
+ }
+ else {
+ this.filePaths.add(path);
+ }
+ } catch (IOException ioe) {
+ throw new RuntimeException("Failed getting list of files at:" + path, ioe);
+ }
+
+ this.currentIndex = -1;
+ }
+
+ private Configuration getDefaultConfig() {
+ String hadoopHome = System.getenv("HADOOP_HOME");
+ Configuration conf = new Configuration();
+ if (hadoopHome != null) {
+ java.nio.file.Path coreSitePath = FileSystems.getDefault().getPath(hadoopHome, "etc/hadoop/core-site.xml");
+ java.nio.file.Path hdfsSitePath = FileSystems.getDefault().getPath(hadoopHome, "etc/hadoop/hdfs-site.xml");
+ conf.addResource(new Path(coreSitePath.toAbsolutePath().toString()));
+ conf.addResource(new Path(hdfsSitePath.toAbsolutePath().toString()));
+ }
+ return conf;
+ }
+
+ public void reset() throws IOException {
+ this.currentIndex = -1;
+ this.closeFileStream();
+ }
+
+ private void closeFileStream() {
+ IOUtils.closeStream(fileStream);
+ }
+
+ public InputStream getNextInputStream() {
+ this.closeFileStream();
+ if (this.currentIndex >= (this.filePaths.size() - 1))
+ return null;
+
+ this.currentIndex++;
+ String filePath = this.filePaths.get(currentIndex);
+
+ Path hdfsPath = new Path(filePath);
+ FileSystem fs;
+ try {
+ fs = FileSystem.get(config);
+ fileStream = fs.open(hdfsPath);
+ } catch (IOException ioe) {
+ this.closeFileStream();
+ throw new RuntimeException("Failed opening file:" + filePath, ioe);
+ }
+
+ return fileStream;
+ }
+
+ public InputStream getCurrentInputStream() {
+ return fileStream;
+ }
+
+ protected int getFilePathListSize() {
+ if (filePaths != null)
+ return filePaths.size();
+ return 0;
+ }
+
+ protected String getFilePathAt(int index) {
+ if (filePaths != null && filePaths.size() > index)
+ return filePaths.get(index);
+ return null;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSource.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSource.java
index c0ab44f..e4ceb70 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSource.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSource.java
@@ -31,101 +31,103 @@
/**
* Source for FileStream for local files
+ *
* @author Casey
*/
public class LocalFileStreamSource implements FileStreamSource {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 3986511547525870698L;
-
- private transient InputStream fileStream;
- private List<String> filePaths;
- private int currentIndex;
-
- public LocalFileStreamSource(){
- this.currentIndex = -1;
- }
-
- public void init(String path, String ext) {
- this.filePaths = new ArrayList<String>();
- File fileAtPath = new File(path);
- if (fileAtPath.isDirectory()) {
- File[] filesInDir = fileAtPath.listFiles(new FileExtensionFilter(ext));
- for (int i=0; i<filesInDir.length; i++) {
- filePaths.add(filesInDir[i].getAbsolutePath());
- }
- }
- else {
- this.filePaths.add(path);
- }
- this.currentIndex = -1;
- }
-
- public void reset() throws IOException {
- this.currentIndex = -1;
- this.closeFileStream();
- }
-
- private void closeFileStream() {
- if (fileStream != null) {
- try {
- fileStream.close();
- } catch (IOException ioe) {
- ioe.printStackTrace();
- }
- }
- }
+ private static final long serialVersionUID = 3986511547525870698L;
- public InputStream getNextInputStream() {
- this.closeFileStream();
-
- if (this.currentIndex >= (this.filePaths.size()-1)) return null;
-
- this.currentIndex++;
- String filePath = this.filePaths.get(currentIndex);
-
- File file = new File(filePath);
- try {
- fileStream = new FileInputStream(file);
- }
- catch(IOException ioe) {
- this.closeFileStream();
- throw new RuntimeException("Failed opening file:"+filePath,ioe);
- }
-
- return fileStream;
- }
+ private transient InputStream fileStream;
+ private List<String> filePaths;
+ private int currentIndex;
- public InputStream getCurrentInputStream() {
- return fileStream;
- }
-
- protected int getFilePathListSize() {
- if (filePaths != null)
- return filePaths.size();
- return 0;
- }
-
- protected String getFilePathAt(int index) {
- if (filePaths != null && filePaths.size() > index)
- return filePaths.get(index);
- return null;
- }
-
- private class FileExtensionFilter implements FilenameFilter {
- private String extension;
- FileExtensionFilter(String ext) {
- extension = ext;
- }
-
- @Override
- public boolean accept(File dir, String name) {
- File f = new File(dir,name);
- if (extension == null)
- return f.isFile();
- else
- return f.isFile() && name.toLowerCase().endsWith("."+extension);
- }
- }
+ public LocalFileStreamSource() {
+ this.currentIndex = -1;
+ }
+
+ public void init(String path, String ext) {
+ this.filePaths = new ArrayList<String>();
+ File fileAtPath = new File(path);
+ if (fileAtPath.isDirectory()) {
+ File[] filesInDir = fileAtPath.listFiles(new FileExtensionFilter(ext));
+ for (int i = 0; i < filesInDir.length; i++) {
+ filePaths.add(filesInDir[i].getAbsolutePath());
+ }
+ }
+ else {
+ this.filePaths.add(path);
+ }
+ this.currentIndex = -1;
+ }
+
+ public void reset() throws IOException {
+ this.currentIndex = -1;
+ this.closeFileStream();
+ }
+
+ private void closeFileStream() {
+ if (fileStream != null) {
+ try {
+ fileStream.close();
+ } catch (IOException ioe) {
+ ioe.printStackTrace();
+ }
+ }
+ }
+
+ public InputStream getNextInputStream() {
+ this.closeFileStream();
+
+ if (this.currentIndex >= (this.filePaths.size() - 1))
+ return null;
+
+ this.currentIndex++;
+ String filePath = this.filePaths.get(currentIndex);
+
+ File file = new File(filePath);
+ try {
+ fileStream = new FileInputStream(file);
+ } catch (IOException ioe) {
+ this.closeFileStream();
+ throw new RuntimeException("Failed opening file:" + filePath, ioe);
+ }
+
+ return fileStream;
+ }
+
+ public InputStream getCurrentInputStream() {
+ return fileStream;
+ }
+
+ protected int getFilePathListSize() {
+ if (filePaths != null)
+ return filePaths.size();
+ return 0;
+ }
+
+ protected String getFilePathAt(int index) {
+ if (filePaths != null && filePaths.size() > index)
+ return filePaths.get(index);
+ return null;
+ }
+
+ private class FileExtensionFilter implements FilenameFilter {
+ private String extension;
+
+ FileExtensionFilter(String ext) {
+ extension = ext;
+ }
+
+ @Override
+ public boolean accept(File dir, String name) {
+ File f = new File(dir, name);
+ if (extension == null)
+ return f.isFile();
+ else
+ return f.isFile() && name.toLowerCase().endsWith("." + extension);
+ }
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/ClusteringEvaluation.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/ClusteringEvaluation.java
index 4af3764..c62de1e 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/ClusteringEvaluation.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/ClusteringEvaluation.java
@@ -50,125 +50,138 @@
*/
public class ClusteringEvaluation implements Task, Configurable {
- private static final long serialVersionUID = -8246537378371580550L;
+ private static final long serialVersionUID = -8246537378371580550L;
- private static final int DISTRIBUTOR_PARALLELISM = 1;
+ private static final int DISTRIBUTOR_PARALLELISM = 1;
- private static final Logger logger = LoggerFactory.getLogger(ClusteringEvaluation.class);
+ private static final Logger logger = LoggerFactory.getLogger(ClusteringEvaluation.class);
- public ClassOption learnerOption = new ClassOption("learner", 'l', "Clustering to run.", Learner.class, DistributedClusterer.class.getName());
+ public ClassOption learnerOption = new ClassOption("learner", 'l', "Clustering to run.", Learner.class,
+ DistributedClusterer.class.getName());
- public ClassOption streamTrainOption = new ClassOption("streamTrain", 's', "Input stream.", InstanceStream.class,
- RandomRBFGeneratorEvents.class.getName());
+ public ClassOption streamTrainOption = new ClassOption("streamTrain", 's', "Input stream.", InstanceStream.class,
+ RandomRBFGeneratorEvents.class.getName());
- public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i', "Maximum number of instances to test/train on (-1 = no limit).", 100000, -1,
- Integer.MAX_VALUE);
+ public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i',
+ "Maximum number of instances to test/train on (-1 = no limit).", 100000, -1,
+ Integer.MAX_VALUE);
- public IntOption measureCollectionTypeOption = new IntOption("measureCollectionType", 'm', "Type of measure collection", 0, 0, Integer.MAX_VALUE);
+ public IntOption measureCollectionTypeOption = new IntOption("measureCollectionType", 'm',
+ "Type of measure collection", 0, 0, Integer.MAX_VALUE);
- public IntOption timeLimitOption = new IntOption("timeLimit", 't', "Maximum number of seconds to test/train for (-1 = no limit).", -1, -1,
- Integer.MAX_VALUE);
+ public IntOption timeLimitOption = new IntOption("timeLimit", 't',
+ "Maximum number of seconds to test/train for (-1 = no limit).", -1, -1,
+ Integer.MAX_VALUE);
- public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f', "How many instances between samples of the learning performance.", 1000, 0,
- Integer.MAX_VALUE);
+ public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f',
+ "How many instances between samples of the learning performance.", 1000, 0,
+ Integer.MAX_VALUE);
- public StringOption evaluationNameOption = new StringOption("evaluationName", 'n', "Identifier of the evaluation", "Clustering__"
- + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
+ public StringOption evaluationNameOption = new StringOption("evaluationName", 'n', "Identifier of the evaluation",
+ "Clustering__"
+ + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
- public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to", null, "csv", true);
+ public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to",
+ null, "csv", true);
- public FloatOption samplingThresholdOption = new FloatOption("samplingThreshold", 'a', "Ratio of instances sampled that will be used for evaluation.", 0.5,
- 0.0, 1.0);
+ public FloatOption samplingThresholdOption = new FloatOption("samplingThreshold", 'a',
+ "Ratio of instances sampled that will be used for evaluation.", 0.5,
+ 0.0, 1.0);
- private ClusteringEntranceProcessor source;
- private InstanceStream streamTrain;
- private ClusteringDistributorProcessor distributor;
- private Stream distributorStream;
- private Stream evaluationStream;
-
- // Default=0: no delay/waiting
- public IntOption sourceDelayOption = new IntOption("sourceDelay", 'w', "How many miliseconds between injections of two instances.", 0, 0, Integer.MAX_VALUE);
-
- private Learner learner;
- private ClusteringEvaluatorProcessor evaluator;
+ private ClusteringEntranceProcessor source;
+ private InstanceStream streamTrain;
+ private ClusteringDistributorProcessor distributor;
+ private Stream distributorStream;
+ private Stream evaluationStream;
- private Topology topology;
- private TopologyBuilder builder;
+ // Default=0: no delay/waiting
+ public IntOption sourceDelayOption = new IntOption("sourceDelay", 'w',
+ "How many milliseconds between injections of two instances.", 0, 0, Integer.MAX_VALUE);
- public void getDescription(StringBuilder sb) {
- sb.append("Clustering evaluation");
+ private Learner learner;
+ private ClusteringEvaluatorProcessor evaluator;
+
+ private Topology topology;
+ private TopologyBuilder builder;
+
+ public void getDescription(StringBuilder sb) {
+ sb.append("Clustering evaluation");
+ }
+
+ @Override
+ public void init() {
+ // TODO remove the if statement theoretically, dynamic binding will work
+ // here! for now, the if statement is used by Storm
+
+ if (builder == null) {
+ logger.warn("Builder was not initialized, initializing it from the Task");
+
+ builder = new TopologyBuilder();
+ logger.debug("Successfully instantiating TopologyBuilder");
+
+ builder.initTopology(evaluationNameOption.getValue(), sourceDelayOption.getValue());
+ logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
}
- @Override
- public void init() {
- // TODO remove the if statement theoretically, dynamic binding will work here! for now, the if statement is used by Storm
+ // instantiate ClusteringEntranceProcessor and its output stream
+ // (sourceStream)
+ source = new ClusteringEntranceProcessor();
+ streamTrain = this.streamTrainOption.getValue();
+ source.setStreamSource(streamTrain);
+ builder.addEntranceProcessor(source);
+ source.setSamplingThreshold(samplingThresholdOption.getValue());
+ source.setMaxNumInstances(instanceLimitOption.getValue());
+ logger.debug("Successfully instantiated ClusteringEntranceProcessor");
- if (builder == null) {
- logger.warn("Builder was not initialized, initializing it from the Task");
+ Stream sourceStream = builder.createStream(source);
+ // starter.setInputStream(sourcePiOutputStream); // FIXME set stream in the
+ // EntrancePI
- builder = new TopologyBuilder();
- logger.debug("Successfully instantiating TopologyBuilder");
+ // distribution of instances and sampling for evaluation
+ distributor = new ClusteringDistributorProcessor();
+ builder.addProcessor(distributor, DISTRIBUTOR_PARALLELISM);
+ builder.connectInputShuffleStream(sourceStream, distributor);
+ distributorStream = builder.createStream(distributor);
+ distributor.setOutputStream(distributorStream);
+ evaluationStream = builder.createStream(distributor);
+ distributor.setEvaluationStream(evaluationStream); // passes evaluation
+ // events along
+ logger.debug("Successfully instantiated Distributor");
- builder.initTopology(evaluationNameOption.getValue(), sourceDelayOption.getValue());
- logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
- }
+ // instantiate learner and connect it to distributorStream
+ learner = this.learnerOption.getValue();
+ learner.init(builder, source.getDataset(), 1);
+ builder.connectInputShuffleStream(distributorStream, learner.getInputProcessor());
+ logger.debug("Successfully instantiated Learner");
- // instantiate ClusteringEntranceProcessor and its output stream (sourceStream)
- source = new ClusteringEntranceProcessor();
- streamTrain = this.streamTrainOption.getValue();
- source.setStreamSource(streamTrain);
- builder.addEntranceProcessor(source);
- source.setSamplingThreshold(samplingThresholdOption.getValue());
- source.setMaxNumInstances(instanceLimitOption.getValue());
- logger.debug("Successfully instantiated ClusteringEntranceProcessor");
-
- Stream sourceStream = builder.createStream(source);
- // starter.setInputStream(sourcePiOutputStream); // FIXME set stream in the EntrancePI
-
- // distribution of instances and sampling for evaluation
- distributor = new ClusteringDistributorProcessor();
- builder.addProcessor(distributor, DISTRIBUTOR_PARALLELISM);
- builder.connectInputShuffleStream(sourceStream, distributor);
- distributorStream = builder.createStream(distributor);
- distributor.setOutputStream(distributorStream);
- evaluationStream = builder.createStream(distributor);
- distributor.setEvaluationStream(evaluationStream); // passes evaluation events along
- logger.debug("Successfully instantiated Distributor");
-
- // instantiate learner and connect it to distributorStream
- learner = this.learnerOption.getValue();
- learner.init(builder, source.getDataset(), 1);
- builder.connectInputShuffleStream(distributorStream, learner.getInputProcessor());
- logger.debug("Successfully instantiated Learner");
-
- evaluator = new ClusteringEvaluatorProcessor.Builder(
+ evaluator = new ClusteringEvaluatorProcessor.Builder(
sampleFrequencyOption.getValue()).dumpFile(dumpFileOption.getFile())
- .decayHorizon(((ClusteringStream) streamTrain).getDecayHorizon()).build();
+ .decayHorizon(((ClusteringStream) streamTrain).getDecayHorizon()).build();
- builder.addProcessor(evaluator);
- for (Stream evaluatorPiInputStream:learner.getResultStreams()) {
- builder.connectInputShuffleStream(evaluatorPiInputStream, evaluator);
- }
- builder.connectInputAllStream(evaluationStream, evaluator);
- logger.debug("Successfully instantiated EvaluatorProcessor");
-
- topology = builder.build();
- logger.debug("Successfully built the topology");
+ builder.addProcessor(evaluator);
+ for (Stream evaluatorPiInputStream : learner.getResultStreams()) {
+ builder.connectInputShuffleStream(evaluatorPiInputStream, evaluator);
}
+ builder.connectInputAllStream(evaluationStream, evaluator);
+ logger.debug("Successfully instantiated EvaluatorProcessor");
- @Override
- public void setFactory(ComponentFactory factory) {
- // TODO unify this code with init() for now, it's used by S4 App
- // dynamic binding theoretically will solve this problem
- builder = new TopologyBuilder(factory);
- logger.debug("Successfully instantiated TopologyBuilder");
+ topology = builder.build();
+ logger.debug("Successfully built the topology");
+ }
- builder.initTopology(evaluationNameOption.getValue());
- logger.debug("Successfully initialized SAMOA topology with name {}", evaluationNameOption.getValue());
+ @Override
+ public void setFactory(ComponentFactory factory) {
+ // TODO unify this code with init() for now, it's used by S4 App
+ // dynamic binding theoretically will solve this problem
+ builder = new TopologyBuilder(factory);
+ logger.debug("Successfully instantiated TopologyBuilder");
- }
+ builder.initTopology(evaluationNameOption.getValue());
+ logger.debug("Successfully initialized SAMOA topology with name {}", evaluationNameOption.getValue());
- public Topology getTopology() {
- return topology;
- }
+ }
+
+ public Topology getTopology() {
+ return topology;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/PrequentialEvaluation.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/PrequentialEvaluation.java
index 70c44a1..081d123 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/PrequentialEvaluation.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/PrequentialEvaluation.java
@@ -49,158 +49,174 @@
import com.yahoo.labs.samoa.topology.TopologyBuilder;
/**
- * Prequential Evaluation task is a scheme in evaluating performance of online classifiers which uses each instance for testing online classifiers model and
- * then it further uses the same instance for training the model(Test-then-train)
+ * Prequential Evaluation task is a scheme in evaluating performance of online
+ * classifiers which uses each instance for testing online classifiers model and
+ * then it further uses the same instance for training the
+ * model(Test-then-train)
*
* @author Arinto Murdopo
*
*/
public class PrequentialEvaluation implements Task, Configurable {
- private static final long serialVersionUID = -8246537378371580550L;
+ private static final long serialVersionUID = -8246537378371580550L;
- private static Logger logger = LoggerFactory.getLogger(PrequentialEvaluation.class);
+ private static Logger logger = LoggerFactory.getLogger(PrequentialEvaluation.class);
- public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
+ public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Learner.class,
+ VerticalHoeffdingTree.class.getName());
- public ClassOption streamTrainOption = new ClassOption("trainStream", 's', "Stream to learn from.", InstanceStream.class,
- RandomTreeGenerator.class.getName());
+ public ClassOption streamTrainOption = new ClassOption("trainStream", 's', "Stream to learn from.",
+ InstanceStream.class,
+ RandomTreeGenerator.class.getName());
- public ClassOption evaluatorOption = new ClassOption("evaluator", 'e', "Classification performance evaluation method.",
- PerformanceEvaluator.class, BasicClassificationPerformanceEvaluator.class.getName());
+ public ClassOption evaluatorOption = new ClassOption("evaluator", 'e',
+ "Classification performance evaluation method.",
+ PerformanceEvaluator.class, BasicClassificationPerformanceEvaluator.class.getName());
- public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i', "Maximum number of instances to test/train on (-1 = no limit).", 1000000, -1,
- Integer.MAX_VALUE);
+ public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i',
+ "Maximum number of instances to test/train on (-1 = no limit).", 1000000, -1,
+ Integer.MAX_VALUE);
- public IntOption timeLimitOption = new IntOption("timeLimit", 't', "Maximum number of seconds to test/train for (-1 = no limit).", -1, -1,
- Integer.MAX_VALUE);
+ public IntOption timeLimitOption = new IntOption("timeLimit", 't',
+ "Maximum number of seconds to test/train for (-1 = no limit).", -1, -1,
+ Integer.MAX_VALUE);
- public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f', "How many instances between samples of the learning performance.", 100000,
- 0, Integer.MAX_VALUE);
+ public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f',
+ "How many instances between samples of the learning performance.", 100000,
+ 0, Integer.MAX_VALUE);
- public StringOption evaluationNameOption = new StringOption("evaluationName", 'n', "Identifier of the evaluation", "Prequential_"
- + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
+ public StringOption evaluationNameOption = new StringOption("evaluationName", 'n', "Identifier of the evaluation",
+ "Prequential_"
+ + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
- public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to", null, "csv", true);
+ public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to",
+ null, "csv", true);
- // Default=0: no delay/waiting
- public IntOption sourceDelayOption = new IntOption("sourceDelay", 'w', "How many microseconds between injections of two instances.", 0, 0, Integer.MAX_VALUE);
- // Batch size to delay the incoming stream: delay of x milliseconds after each batch
- public IntOption batchDelayOption = new IntOption("delayBatchSize", 'b', "The delay batch size: delay of x milliseconds after each batch ", 1, 1, Integer.MAX_VALUE);
-
- private PrequentialSourceProcessor preqSource;
+ // Default=0: no delay/waiting
+ public IntOption sourceDelayOption = new IntOption("sourceDelay", 'w',
+ "How many microseconds between injections of two instances.", 0, 0, Integer.MAX_VALUE);
+ // Batch size to delay the incoming stream: delay of x milliseconds after each
+ // batch
+ public IntOption batchDelayOption = new IntOption("delayBatchSize", 'b',
+ "The delay batch size: delay of x milliseconds after each batch ", 1, 1, Integer.MAX_VALUE);
- // private PrequentialSourceTopologyStarter preqStarter;
+ private PrequentialSourceProcessor preqSource;
- // private EntranceProcessingItem sourcePi;
+ // private PrequentialSourceTopologyStarter preqStarter;
- private Stream sourcePiOutputStream;
+ // private EntranceProcessingItem sourcePi;
- private Learner classifier;
+ private Stream sourcePiOutputStream;
- private EvaluatorProcessor evaluator;
+ private Learner classifier;
- // private ProcessingItem evaluatorPi;
+ private EvaluatorProcessor evaluator;
- // private Stream evaluatorPiInputStream;
+ // private ProcessingItem evaluatorPi;
- private Topology prequentialTopology;
+ // private Stream evaluatorPiInputStream;
- private TopologyBuilder builder;
+ private Topology prequentialTopology;
- public void getDescription(StringBuilder sb, int indent) {
- sb.append("Prequential evaluation");
+ private TopologyBuilder builder;
+
+ public void getDescription(StringBuilder sb, int indent) {
+ sb.append("Prequential evaluation");
+ }
+
+ @Override
+ public void init() {
+ // TODO remove the if statement
+ // theoretically, dynamic binding will work here!
+ // test later!
+ // for now, the if statement is used by Storm
+
+ if (builder == null) {
+ builder = new TopologyBuilder();
+ logger.debug("Successfully instantiating TopologyBuilder");
+
+ builder.initTopology(evaluationNameOption.getValue());
+ logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
}
- @Override
- public void init() {
- // TODO remove the if statement
- // theoretically, dynamic binding will work here!
- // test later!
- // for now, the if statement is used by Storm
+ // instantiate PrequentialSourceProcessor and its output stream
+ // (sourcePiOutputStream)
+ preqSource = new PrequentialSourceProcessor();
+ preqSource.setStreamSource((InstanceStream) this.streamTrainOption.getValue());
+ preqSource.setMaxNumInstances(instanceLimitOption.getValue());
+ preqSource.setSourceDelay(sourceDelayOption.getValue());
+ preqSource.setDelayBatchSize(batchDelayOption.getValue());
+ builder.addEntranceProcessor(preqSource);
+ logger.debug("Successfully instantiating PrequentialSourceProcessor");
- if (builder == null) {
- builder = new TopologyBuilder();
- logger.debug("Successfully instantiating TopologyBuilder");
+ // preqStarter = new PrequentialSourceTopologyStarter(preqSource,
+ // instanceLimitOption.getValue());
+ // sourcePi = builder.createEntrancePi(preqSource, preqStarter);
+ // sourcePiOutputStream = builder.createStream(sourcePi);
- builder.initTopology(evaluationNameOption.getValue());
- logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
- }
+ sourcePiOutputStream = builder.createStream(preqSource);
+ // preqStarter.setInputStream(sourcePiOutputStream);
- // instantiate PrequentialSourceProcessor and its output stream (sourcePiOutputStream)
- preqSource = new PrequentialSourceProcessor();
- preqSource.setStreamSource((InstanceStream) this.streamTrainOption.getValue());
- preqSource.setMaxNumInstances(instanceLimitOption.getValue());
- preqSource.setSourceDelay(sourceDelayOption.getValue());
- preqSource.setDelayBatchSize(batchDelayOption.getValue());
- builder.addEntranceProcessor(preqSource);
- logger.debug("Successfully instantiating PrequentialSourceProcessor");
+ // instantiate classifier and connect it to sourcePiOutputStream
+ classifier = this.learnerOption.getValue();
+ classifier.init(builder, preqSource.getDataset(), 1);
+ builder.connectInputShuffleStream(sourcePiOutputStream, classifier.getInputProcessor());
+ logger.debug("Successfully instantiating Classifier");
- // preqStarter = new PrequentialSourceTopologyStarter(preqSource, instanceLimitOption.getValue());
- // sourcePi = builder.createEntrancePi(preqSource, preqStarter);
- // sourcePiOutputStream = builder.createStream(sourcePi);
+ PerformanceEvaluator evaluatorOptionValue = this.evaluatorOption.getValue();
+ if (!PrequentialEvaluation.isLearnerAndEvaluatorCompatible(classifier, evaluatorOptionValue)) {
+ evaluatorOptionValue = getDefaultPerformanceEvaluatorForLearner(classifier);
+ }
+ evaluator = new EvaluatorProcessor.Builder(evaluatorOptionValue)
+ .samplingFrequency(sampleFrequencyOption.getValue()).dumpFile(dumpFileOption.getFile()).build();
- sourcePiOutputStream = builder.createStream(preqSource);
- // preqStarter.setInputStream(sourcePiOutputStream);
-
- // instantiate classifier and connect it to sourcePiOutputStream
- classifier = this.learnerOption.getValue();
- classifier.init(builder, preqSource.getDataset(), 1);
- builder.connectInputShuffleStream(sourcePiOutputStream, classifier.getInputProcessor());
- logger.debug("Successfully instantiating Classifier");
-
- PerformanceEvaluator evaluatorOptionValue = this.evaluatorOption.getValue();
- if (!PrequentialEvaluation.isLearnerAndEvaluatorCompatible(classifier, evaluatorOptionValue)) {
- evaluatorOptionValue = getDefaultPerformanceEvaluatorForLearner(classifier);
- }
- evaluator = new EvaluatorProcessor.Builder(evaluatorOptionValue)
- .samplingFrequency(sampleFrequencyOption.getValue()).dumpFile(dumpFileOption.getFile()).build();
-
- // evaluatorPi = builder.createPi(evaluator);
- // evaluatorPi.connectInputShuffleStream(evaluatorPiInputStream);
- builder.addProcessor(evaluator);
- for (Stream evaluatorPiInputStream:classifier.getResultStreams()) {
- builder.connectInputShuffleStream(evaluatorPiInputStream, evaluator);
- }
-
- logger.debug("Successfully instantiating EvaluatorProcessor");
-
- prequentialTopology = builder.build();
- logger.debug("Successfully building the topology");
+ // evaluatorPi = builder.createPi(evaluator);
+ // evaluatorPi.connectInputShuffleStream(evaluatorPiInputStream);
+ builder.addProcessor(evaluator);
+ for (Stream evaluatorPiInputStream : classifier.getResultStreams()) {
+ builder.connectInputShuffleStream(evaluatorPiInputStream, evaluator);
}
- @Override
- public void setFactory(ComponentFactory factory) {
- // TODO unify this code with init()
- // for now, it's used by S4 App
- // dynamic binding theoretically will solve this problem
- builder = new TopologyBuilder(factory);
- logger.debug("Successfully instantiating TopologyBuilder");
+ logger.debug("Successfully instantiating EvaluatorProcessor");
- builder.initTopology(evaluationNameOption.getValue());
- logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
+ prequentialTopology = builder.build();
+ logger.debug("Successfully building the topology");
+ }
- }
+ @Override
+ public void setFactory(ComponentFactory factory) {
+ // TODO unify this code with init()
+ // for now, it's used by S4 App
+ // dynamic binding theoretically will solve this problem
+ builder = new TopologyBuilder(factory);
+ logger.debug("Successfully instantiating TopologyBuilder");
- public Topology getTopology() {
- return prequentialTopology;
+ builder.initTopology(evaluationNameOption.getValue());
+ logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
+
+ }
+
+ public Topology getTopology() {
+ return prequentialTopology;
+ }
+
+ //
+ // @Override
+ // public TopologyStarter getTopologyStarter() {
+ // return this.preqStarter;
+ // }
+
+ private static boolean isLearnerAndEvaluatorCompatible(Learner learner, PerformanceEvaluator evaluator) {
+ return (learner instanceof RegressionLearner && evaluator instanceof RegressionPerformanceEvaluator) ||
+ (learner instanceof ClassificationLearner && evaluator instanceof ClassificationPerformanceEvaluator);
+ }
+
+ private static PerformanceEvaluator getDefaultPerformanceEvaluatorForLearner(Learner learner) {
+ if (learner instanceof RegressionLearner) {
+ return new BasicRegressionPerformanceEvaluator();
}
- //
- // @Override
- // public TopologyStarter getTopologyStarter() {
- // return this.preqStarter;
- // }
-
- private static boolean isLearnerAndEvaluatorCompatible(Learner learner, PerformanceEvaluator evaluator) {
- return (learner instanceof RegressionLearner && evaluator instanceof RegressionPerformanceEvaluator) ||
- (learner instanceof ClassificationLearner && evaluator instanceof ClassificationPerformanceEvaluator);
- }
-
- private static PerformanceEvaluator getDefaultPerformanceEvaluatorForLearner(Learner learner) {
- if (learner instanceof RegressionLearner) {
- return new BasicRegressionPerformanceEvaluator();
- }
- // Default to BasicClassificationPerformanceEvaluator for all other cases
- return new BasicClassificationPerformanceEvaluator();
- }
+ // Default to BasicClassificationPerformanceEvaluator for all other cases
+ return new BasicClassificationPerformanceEvaluator();
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/Task.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/Task.java
index 41b47e4..5753349 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/Task.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/tasks/Task.java
@@ -28,34 +28,34 @@
*/
public interface Task {
- /**
- * Initialize this SAMOA task,
- * i.e. create and connect ProcessingItems and Streams
- * and initialize the topology
- */
- public void init();
-
- /**
- * Return the final topology object to be executed in the cluster
- * @return topology object to be submitted to be executed in the cluster
- */
- public Topology getTopology();
-
- // /**
- // * Return the entrance processor to start SAMOA topology
- // * The logic to start the topology should be implemented here
- // * @return entrance processor to start the topology
- // */
- // public TopologyStarter getTopologyStarter();
-
- /**
- * Sets the factory.
- * TODO: propose to hide factory from task,
- * i.e. Task will only see TopologyBuilder,
- * and factory creation will be handled by TopologyBuilder
- *
- * @param factory the new factory
- */
- public void setFactory(ComponentFactory factory) ;
-
+ /**
+ * Initialize this SAMOA task, i.e. create and connect ProcessingItems and
+ * Streams and initialize the topology
+ */
+ public void init();
+
+ /**
+ * Return the final topology object to be executed in the cluster
+ *
+ * @return topology object to be submitted to be executed in the cluster
+ */
+ public Topology getTopology();
+
+ // /**
+ // * Return the entrance processor to start SAMOA topology
+ // * The logic to start the topology should be implemented here
+ // * @return entrance processor to start the topology
+ // */
+ // public TopologyStarter getTopologyStarter();
+
+ /**
+ * Sets the factory. TODO: propose to hide factory from task, i.e. Task will
+ * only see TopologyBuilder, and factory creation will be handled by
+ * TopologyBuilder
+ *
+ * @param factory
+ * the new factory
+ */
+ public void setFactory(ComponentFactory factory);
+
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractEntranceProcessingItem.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractEntranceProcessingItem.java
index c0f0cc3..1e3c9b5 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractEntranceProcessingItem.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractEntranceProcessingItem.java
@@ -24,85 +24,93 @@
/**
* Helper class for EntranceProcessingItem implementation.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public abstract class AbstractEntranceProcessingItem implements EntranceProcessingItem {
- private EntranceProcessor processor;
- private String name;
- private Stream outputStream;
-
- /*
- * Constructor
- */
- public AbstractEntranceProcessingItem() {
- this(null);
- }
- public AbstractEntranceProcessingItem(EntranceProcessor processor) {
- this.processor = processor;
- }
-
- /*
- * Processor
- */
- /**
- * Set the entrance processor for this EntranceProcessingItem
- * @param processor
- * the processor
- */
- protected void setProcessor(EntranceProcessor processor) {
- this.processor = processor;
- }
-
- /**
- * Get the EntranceProcessor of this EntranceProcessingItem.
- * @return the EntranceProcessor
- */
- public EntranceProcessor getProcessor() {
- return this.processor;
- }
-
- /*
- * Name/ID
- */
- /**
- * Set the name (or ID) of this EntranceProcessingItem
- * @param name
- */
- public void setName(String name) {
- this.name = name;
- }
-
- /**
- * Get the name (or ID) of this EntranceProcessingItem
- * @return the name (or ID)
- */
- public String getName() {
- return this.name;
- }
-
- /*
- * Output Stream
- */
- /**
- * Set the output stream of this EntranceProcessingItem.
- * An EntranceProcessingItem should have only 1 single output stream and
- * should not be re-assigned.
- * @return this EntranceProcessingItem
- */
- public EntranceProcessingItem setOutputStream(Stream outputStream) {
- if (this.outputStream != null && this.outputStream != outputStream) {
- throw new IllegalStateException("Cannot overwrite output stream of EntranceProcessingItem");
- } else
- this.outputStream = outputStream;
- return this;
- }
-
- /**
- * Get the output stream of this EntranceProcessingItem.
- * @return the output stream
- */
- public Stream getOutputStream() {
- return this.outputStream;
- }
+ private EntranceProcessor processor;
+ private String name;
+ private Stream outputStream;
+
+ /*
+ * Constructor
+ */
+ public AbstractEntranceProcessingItem() {
+ this(null);
+ }
+
+ public AbstractEntranceProcessingItem(EntranceProcessor processor) {
+ this.processor = processor;
+ }
+
+ /*
+ * Processor
+ */
+ /**
+ * Set the entrance processor for this EntranceProcessingItem
+ *
+ * @param processor
+ * the processor
+ */
+ protected void setProcessor(EntranceProcessor processor) {
+ this.processor = processor;
+ }
+
+ /**
+ * Get the EntranceProcessor of this EntranceProcessingItem.
+ *
+ * @return the EntranceProcessor
+ */
+ public EntranceProcessor getProcessor() {
+ return this.processor;
+ }
+
+ /*
+ * Name/ID
+ */
+ /**
+ * Set the name (or ID) of this EntranceProcessingItem
+ *
+ * @param name
+ */
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ /**
+ * Get the name (or ID) of this EntranceProcessingItem
+ *
+ * @return the name (or ID)
+ */
+ public String getName() {
+ return this.name;
+ }
+
+ /*
+ * Output Stream
+ */
+ /**
+ * Set the output stream of this EntranceProcessingItem. An
+ * EntranceProcessingItem should have a single output stream and should
+ * not be re-assigned.
+ *
+ * @return this EntranceProcessingItem
+ */
+ public EntranceProcessingItem setOutputStream(Stream outputStream) {
+ if (this.outputStream != null && this.outputStream != outputStream) {
+ throw new IllegalStateException("Cannot overwrite output stream of EntranceProcessingItem");
+ } else
+ this.outputStream = outputStream;
+ return this;
+ }
+
+ /**
+ * Get the output stream of this EntranceProcessingItem.
+ *
+ * @return the output stream
+ */
+ public Stream getOutputStream() {
+ return this.outputStream;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractProcessingItem.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractProcessingItem.java
index d0f04f7..60d76e0 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractProcessingItem.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractProcessingItem.java
@@ -26,136 +26,145 @@
/**
* Abstract ProcessingItem
*
- * Helper for implementation of ProcessingItem. It has basic information
- * for a ProcessingItem: name, parallelismLevel and a processor.
- * Subclass of this class needs to implement {@link #addInputStream(Stream, PartitioningScheme)}.
+ * Helper for implementation of ProcessingItem. It has basic information for a
+ * ProcessingItem: name, parallelismLevel and a processor. Subclass of this
+ * class needs to implement {@link #addInputStream(Stream, PartitioningScheme)}.
*
* @author Anh Thu Vu
- *
+ *
*/
public abstract class AbstractProcessingItem implements ProcessingItem {
- private String name;
- private int parallelism;
- private Processor processor;
-
- /*
- * Constructor
- */
- public AbstractProcessingItem() {
- this(null);
- }
- public AbstractProcessingItem(Processor processor) {
- this(processor,1);
- }
- public AbstractProcessingItem(Processor processor, int parallelism) {
- this.processor = processor;
- this.parallelism = parallelism;
- }
-
- /*
- * Processor
- */
- /**
- * Set the processor for this ProcessingItem
- * @param processor
- * the processor
- */
- protected void setProcessor(Processor processor) {
- this.processor = processor;
- }
-
- /**
- * Get the processor of this ProcessingItem
- * @return the processor
- */
- public Processor getProcessor() {
- return this.processor;
- }
-
- /*
- * Parallelism
- */
- /**
- * Set the parallelism factor of this ProcessingItem
- * @param parallelism
- */
- protected void setParallelism(int parallelism) {
- this.parallelism = parallelism;
- }
-
- /**
- * Get the parallelism factor of this ProcessingItem
- * @return the parallelism factor
- */
- @Override
- public int getParallelism() {
- return this.parallelism;
- }
-
- /*
- * Name/ID
- */
- /**
- * Set the name (or ID) of this ProcessingItem
- * @param name
- * the name/ID
- */
- public void setName(String name) {
- this.name = name;
- }
-
- /**
- * Get the name (or ID) of this ProcessingItem
- * @return the name/ID
- */
- public String getName() {
- return this.name;
- }
-
- /*
- * Add input streams
- */
- /**
- * Add an input stream to this ProcessingItem
- *
- * @param inputStream
- * the input stream to add
- * @param scheme
- * partitioning scheme associated with this ProcessingItem and the input stream
- * @return this ProcessingItem
- */
- protected abstract ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme);
+ private String name;
+ private int parallelism;
+ private Processor processor;
- /**
- * Add an input stream to this ProcessingItem with SHUFFLE scheme
- *
- * @param inputStream
- * the input stream
- * @return this ProcessingItem
- */
- public ProcessingItem connectInputShuffleStream(Stream inputStream) {
- return this.addInputStream(inputStream, PartitioningScheme.SHUFFLE);
- }
+ /*
+ * Constructor
+ */
+ public AbstractProcessingItem() {
+ this(null);
+ }
- /**
- * Add an input stream to this ProcessingItem with GROUP_BY_KEY scheme
- *
- * @param inputStream
- * the input stream
- * @return this ProcessingItem
- */
- public ProcessingItem connectInputKeyStream(Stream inputStream) {
- return this.addInputStream(inputStream, PartitioningScheme.GROUP_BY_KEY);
- }
+ public AbstractProcessingItem(Processor processor) {
+ this(processor, 1);
+ }
- /**
- * Add an input stream to this ProcessingItem with BROADCAST scheme
- *
- * @param inputStream
- * the input stream
- * @return this ProcessingItem
- */
- public ProcessingItem connectInputAllStream(Stream inputStream) {
- return this.addInputStream(inputStream, PartitioningScheme.BROADCAST);
- }
+ public AbstractProcessingItem(Processor processor, int parallelism) {
+ this.processor = processor;
+ this.parallelism = parallelism;
+ }
+
+ /*
+ * Processor
+ */
+ /**
+ * Set the processor for this ProcessingItem
+ *
+ * @param processor
+ * the processor
+ */
+ protected void setProcessor(Processor processor) {
+ this.processor = processor;
+ }
+
+ /**
+ * Get the processor of this ProcessingItem
+ *
+ * @return the processor
+ */
+ public Processor getProcessor() {
+ return this.processor;
+ }
+
+ /*
+ * Parallelism
+ */
+ /**
+ * Set the parallelism factor of this ProcessingItem
+ *
+ * @param parallelism
+ */
+ protected void setParallelism(int parallelism) {
+ this.parallelism = parallelism;
+ }
+
+ /**
+ * Get the parallelism factor of this ProcessingItem
+ *
+ * @return the parallelism factor
+ */
+ @Override
+ public int getParallelism() {
+ return this.parallelism;
+ }
+
+ /*
+ * Name/ID
+ */
+ /**
+ * Set the name (or ID) of this ProcessingItem
+ *
+ * @param name
+ * the name/ID
+ */
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ /**
+ * Get the name (or ID) of this ProcessingItem
+ *
+ * @return the name/ID
+ */
+ public String getName() {
+ return this.name;
+ }
+
+ /*
+ * Add input streams
+ */
+ /**
+ * Add an input stream to this ProcessingItem
+ *
+ * @param inputStream
+ * the input stream to add
+ * @param scheme
+ * partitioning scheme associated with this ProcessingItem and the
+ * input stream
+ * @return this ProcessingItem
+ */
+ protected abstract ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme);
+
+ /**
+ * Add an input stream to this ProcessingItem with SHUFFLE scheme
+ *
+ * @param inputStream
+ * the input stream
+ * @return this ProcessingItem
+ */
+ public ProcessingItem connectInputShuffleStream(Stream inputStream) {
+ return this.addInputStream(inputStream, PartitioningScheme.SHUFFLE);
+ }
+
+ /**
+ * Add an input stream to this ProcessingItem with GROUP_BY_KEY scheme
+ *
+ * @param inputStream
+ * the input stream
+ * @return this ProcessingItem
+ */
+ public ProcessingItem connectInputKeyStream(Stream inputStream) {
+ return this.addInputStream(inputStream, PartitioningScheme.GROUP_BY_KEY);
+ }
+
+ /**
+ * Add an input stream to this ProcessingItem with BROADCAST scheme
+ *
+ * @param inputStream
+ * the input stream
+ * @return this ProcessingItem
+ */
+ public ProcessingItem connectInputAllStream(Stream inputStream) {
+ return this.addInputStream(inputStream, PartitioningScheme.BROADCAST);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractStream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractStream.java
index b3544ed..a16d566 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractStream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractStream.java
@@ -25,91 +25,95 @@
/**
* Abstract Stream
*
- * Helper for implementation of Stream. It has basic information
- * for a Stream: streamID and source ProcessingItem.
- * Subclass of this class needs to implement {@link #put(ContentEvent)}.
+ * Helper for implementation of Stream. It has basic information for a Stream:
+ * streamID and source ProcessingItem. Subclass of this class needs to implement
+ * {@link #put(ContentEvent)}.
*
* @author Anh Thu Vu
- *
+ *
*/
public abstract class AbstractStream implements Stream {
- private String streamID;
- private IProcessingItem sourcePi;
- private int batchSize;
-
- /*
- * Constructor
- */
- public AbstractStream() {
- this(null);
- }
- public AbstractStream(IProcessingItem sourcePi) {
- this.sourcePi = sourcePi;
- this.batchSize = 1;
- }
-
- /**
- * Get source processing item of this stream
- * @return
- */
- public IProcessingItem getSourceProcessingItem() {
- return this.sourcePi;
- }
+ private String streamID;
+ private IProcessingItem sourcePi;
+ private int batchSize;
- /*
- * Process event
- */
- @Override
- /**
- * Send a ContentEvent
- * @param event
- * the ContentEvent to be sent
- */
- public abstract void put(ContentEvent event);
+ /*
+ * Constructor
+ */
+ public AbstractStream() {
+ this(null);
+ }
- /*
- * Stream name
- */
- /**
- * Get name (ID) of this stream
- * @return the name (ID)
- */
- @Override
- public String getStreamId() {
- return this.streamID;
- }
-
- /**
- * Set the name (ID) of this stream
- * @param streamID
- * the name (ID)
- */
- public void setStreamId (String streamID) {
- this.streamID = streamID;
- }
-
- /*
- * Batch size
- */
- /**
- * Set suggested batch size
- *
- * @param batchSize
- * the suggested batch size
- *
- */
- @Override
- public void setBatchSize(int batchSize) {
- this.batchSize = batchSize;
- }
+ public AbstractStream(IProcessingItem sourcePi) {
+ this.sourcePi = sourcePi;
+ this.batchSize = 1;
+ }
- /**
- * Get suggested batch size
- *
- * @return the suggested batch size
- */
- public int getBatchSize() {
- return this.batchSize;
- }
+ /**
+ * Get source processing item of this stream
+ *
+ * @return the source processing item of this stream
+ */
+ public IProcessingItem getSourceProcessingItem() {
+ return this.sourcePi;
+ }
+
+ /*
+ * Process event
+ */
+ @Override
+ /**
+ * Send a ContentEvent
+ * @param event
+ * the ContentEvent to be sent
+ */
+ public abstract void put(ContentEvent event);
+
+ /*
+ * Stream name
+ */
+ /**
+ * Get name (ID) of this stream
+ *
+ * @return the name (ID)
+ */
+ @Override
+ public String getStreamId() {
+ return this.streamID;
+ }
+
+ /**
+ * Set the name (ID) of this stream
+ *
+ * @param streamID
+ * the name (ID)
+ */
+ public void setStreamId(String streamID) {
+ this.streamID = streamID;
+ }
+
+ /*
+ * Batch size
+ */
+ /**
+ * Set suggested batch size
+ *
+ * @param batchSize
+ * the suggested batch size
+ *
+ */
+ @Override
+ public void setBatchSize(int batchSize) {
+ this.batchSize = batchSize;
+ }
+
+ /**
+ * Get suggested batch size
+ *
+ * @return the suggested batch size
+ */
+ public int getBatchSize() {
+ return this.batchSize;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractTopology.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractTopology.java
index 53385b1..00096ca 100755
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractTopology.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/AbstractTopology.java
@@ -26,108 +26,109 @@
/**
* Topology abstract class.
*
- * It manages basic information of a topology: name, sets of Streams and ProcessingItems
+ * It manages basic information of a topology: name, sets of Streams and
+ * ProcessingItems
*
*/
public abstract class AbstractTopology implements Topology {
- private String topoName;
- private Set<Stream> streams;
- private Set<IProcessingItem> processingItems;
- private Set<EntranceProcessingItem> entranceProcessingItems;
+ private String topoName;
+ private Set<Stream> streams;
+ private Set<IProcessingItem> processingItems;
+ private Set<EntranceProcessingItem> entranceProcessingItems;
- protected AbstractTopology(String name) {
- this.topoName = name;
- this.streams = new HashSet<>();
- this.processingItems = new HashSet<>();
- this.entranceProcessingItems = new HashSet<>();
- }
-
- /**
- * Gets the name of this topology
- *
- * @return name of the topology
- */
- public String getTopologyName() {
- return this.topoName;
- }
-
- /**
- * Sets the name of this topology
- *
- * @param topologyName
- * name of the topology
- */
- public void setTopologyName(String topologyName) {
- this.topoName = topologyName;
- }
-
- /**
- * Adds an Entrance processing item to the topology.
- *
- * @param epi
- * Entrance processing item
- */
- public void addEntranceProcessingItem(EntranceProcessingItem epi) {
- this.entranceProcessingItems.add(epi);
- this.addProcessingItem(epi);
- }
-
- /**
- * Gets entrance processing items in the topology
- *
- * @return the set of processing items
- */
- public Set<EntranceProcessingItem> getEntranceProcessingItems() {
- return this.entranceProcessingItems;
- }
+ protected AbstractTopology(String name) {
+ this.topoName = name;
+ this.streams = new HashSet<>();
+ this.processingItems = new HashSet<>();
+ this.entranceProcessingItems = new HashSet<>();
+ }
- /**
- * Add processing item to topology.
- *
- * @param procItem
- * Processing item.
- */
- public void addProcessingItem(IProcessingItem procItem) {
- addProcessingItem(procItem, 1);
- }
+ /**
+ * Gets the name of this topology
+ *
+ * @return name of the topology
+ */
+ public String getTopologyName() {
+ return this.topoName;
+ }
- /**
- * Add processing item to topology.
- *
- * @param procItem
- * Processing item.
- * @param parallelismHint
- * Processing item parallelism level.
- */
- public void addProcessingItem(IProcessingItem procItem, int parallelismHint) {
- this.processingItems.add(procItem);
- }
-
- /**
- * Gets processing items in the topology (including entrance processing items)
- *
- * @return the set of processing items
- */
- public Set<IProcessingItem> getProcessingItems() {
- return this.processingItems;
- }
+ /**
+ * Sets the name of this topology
+ *
+ * @param topologyName
+ * name of the topology
+ */
+ public void setTopologyName(String topologyName) {
+ this.topoName = topologyName;
+ }
- /**
- * Add stream to topology.
- *
- * @param stream
- */
- public void addStream(Stream stream) {
- this.streams.add(stream);
- }
-
- /**
- * Gets streams in the topology
- *
- * @return the set of streams
- */
- public Set<Stream> getStreams() {
- return this.streams;
- }
+ /**
+ * Adds an Entrance processing item to the topology.
+ *
+ * @param epi
+ * Entrance processing item
+ */
+ public void addEntranceProcessingItem(EntranceProcessingItem epi) {
+ this.entranceProcessingItems.add(epi);
+ this.addProcessingItem(epi);
+ }
+
+ /**
+ * Gets entrance processing items in the topology
+ *
+ * @return the set of processing items
+ */
+ public Set<EntranceProcessingItem> getEntranceProcessingItems() {
+ return this.entranceProcessingItems;
+ }
+
+ /**
+ * Add processing item to topology.
+ *
+ * @param procItem
+ * Processing item.
+ */
+ public void addProcessingItem(IProcessingItem procItem) {
+ addProcessingItem(procItem, 1);
+ }
+
+ /**
+ * Add processing item to topology.
+ *
+ * @param procItem
+ * Processing item.
+ * @param parallelismHint
+ * Processing item parallelism level.
+ */
+ public void addProcessingItem(IProcessingItem procItem, int parallelismHint) {
+ this.processingItems.add(procItem);
+ }
+
+ /**
+ * Gets processing items in the topology (including entrance processing items)
+ *
+ * @return the set of processing items
+ */
+ public Set<IProcessingItem> getProcessingItems() {
+ return this.processingItems;
+ }
+
+ /**
+ * Add stream to topology.
+ *
+ * @param stream
+ */
+ public void addStream(Stream stream) {
+ this.streams.add(stream);
+ }
+
+ /**
+ * Gets streams in the topology
+ *
+ * @return the set of streams
+ */
+ public Set<Stream> getStreams() {
+ return this.streams;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ComponentFactory.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ComponentFactory.java
index 433f516..f1e82dc 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ComponentFactory.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ComponentFactory.java
@@ -28,51 +28,55 @@
*/
public interface ComponentFactory {
- /**
- * Creates a platform specific processing item with the specified processor.
- *
- * @param processor
- * contains the logic for this processing item.
- * @return ProcessingItem
- */
- public ProcessingItem createPi(Processor processor);
+ /**
+ * Creates a platform specific processing item with the specified processor.
+ *
+ * @param processor
+ * contains the logic for this processing item.
+ * @return ProcessingItem
+ */
+ public ProcessingItem createPi(Processor processor);
- /**
- * Creates a platform specific processing item with the specified processor. Additionally sets the parallelism level.
- *
- * @param processor
- * contains the logic for this processing item.
- * @param parallelism
- * defines the amount of instances of this processing item will be created.
- * @return ProcessingItem
- */
- public ProcessingItem createPi(Processor processor, int parallelism);
+ /**
+ * Creates a platform specific processing item with the specified processor.
+ * Additionally sets the parallelism level.
+ *
+ * @param processor
+ * contains the logic for this processing item.
+ * @param parallelism
+ * defines the amount of instances of this processing item will be
+ * created.
+ * @return ProcessingItem
+ */
+ public ProcessingItem createPi(Processor processor, int parallelism);
- /**
- * Creates a platform specific processing item with the specified processor that is the entrance point in the topology. This processing item can either
- * generate a stream of data or connect to an external stream of data.
- *
- * @param entranceProcessor
- * contains the logic for this processing item.
- * @return EntranceProcessingItem
- */
- public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor);
+ /**
+ * Creates a platform specific processing item with the specified processor
+ * that is the entrance point in the topology. This processing item can either
+ * generate a stream of data or connect to an external stream of data.
+ *
+ * @param entranceProcessor
+ * contains the logic for this processing item.
+ * @return EntranceProcessingItem
+ */
+ public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor);
- /**
- * Creates a platform specific stream.
- *
- * @param sourcePi
- * source processing item which will provide the events for this stream.
- * @return Stream
- */
- public Stream createStream(IProcessingItem sourcePi);
+ /**
+ * Creates a platform specific stream.
+ *
+ * @param sourcePi
+ * source processing item which will provide the events for this
+ * stream.
+ * @return Stream
+ */
+ public Stream createStream(IProcessingItem sourcePi);
- /**
- * Creates a platform specific topology.
- *
- * @param topoName
- * Topology name.
- * @return Topology
- */
- public Topology createTopology(String topoName);
+ /**
+ * Creates a platform specific topology.
+ *
+ * @param topoName
+ * Topology name.
+ * @return Topology
+ */
+ public Topology createTopology(String topoName);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/EntranceProcessingItem.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/EntranceProcessingItem.java
index 32ed109..9698cc3 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/EntranceProcessingItem.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/EntranceProcessingItem.java
@@ -27,20 +27,21 @@
*/
public interface EntranceProcessingItem extends IProcessingItem {
- @Override
- /**
- * Gets the processing item processor.
- *
- * @return the embedded EntranceProcessor.
- */
- public EntranceProcessor getProcessor();
+ @Override
+ /**
+ * Gets the processing item processor.
+ *
+ * @return the embedded EntranceProcessor.
+ */
+ public EntranceProcessor getProcessor();
- /**
- * Set the single output stream for this EntranceProcessingItem.
- *
- * @param stream
- * the stream
- * @return the current instance of the EntranceProcessingItem for fluent interface.
- */
- public EntranceProcessingItem setOutputStream(Stream stream);
+ /**
+ * Set the single output stream for this EntranceProcessingItem.
+ *
+ * @param stream
+ * the stream
+ * @return the current instance of the EntranceProcessingItem for fluent
+ * interface.
+ */
+ public EntranceProcessingItem setOutputStream(Stream stream);
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/IProcessingItem.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/IProcessingItem.java
index 7a70dc4..97ea9a4 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/IProcessingItem.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/IProcessingItem.java
@@ -26,22 +26,22 @@
* ProcessingItem interface specific for entrance processing items.
*
* @author severien
- *
+ *
*/
public interface IProcessingItem {
-
- /**
- * Gets the processing item processor.
- *
- * @return Processor
- */
- public Processor getProcessor();
-
- /**
- * Sets processing item name.
- *
- * @param name
- */
- //public void setName(String name);
+
+ /**
+ * Gets the processing item processor.
+ *
+ * @return Processor
+ */
+ public Processor getProcessor();
+
+ /**
+ * Sets processing item name.
+ *
+ * @param name
+ */
+ // public void setName(String name);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ISubmitter.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ISubmitter.java
index 8499f80..10c0690 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ISubmitter.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ISubmitter.java
@@ -23,24 +23,25 @@
import com.yahoo.labs.samoa.tasks.Task;
/**
- * Submitter interface for programatically deploying platform specific topologies.
+ * Submitter interface for programmatically deploying platform specific
+ * topologies.
*
* @author severien
- *
+ *
*/
public interface ISubmitter {
- /**
- * Deploy a specific task to a platform.
- *
- * @param task
- */
- public void deployTask(Task task);
-
- /**
- * Sets if the task should run locally or distributed.
- *
- * @param bool
- */
- public void setLocal(boolean bool);
+ /**
+ * Deploy a specific task to a platform.
+ *
+ * @param task
+ */
+ public void deployTask(Task task);
+
+ /**
+ * Sets if the task should run locally or distributed.
+ *
+ * @param bool
+ */
+ public void setLocal(boolean bool);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/LocalEntranceProcessingItem.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/LocalEntranceProcessingItem.java
index 2e8758f..1fe4642 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/LocalEntranceProcessingItem.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/LocalEntranceProcessingItem.java
@@ -24,62 +24,63 @@
import com.yahoo.labs.samoa.core.EntranceProcessor;
/**
- * Implementation of EntranceProcessingItem for local engines (Simple, Multithreads)
+ * Implementation of EntranceProcessingItem for local engines (Simple,
+ * Multithreads)
*
* @author Anh Thu Vu
- *
+ *
*/
public class LocalEntranceProcessingItem extends AbstractEntranceProcessingItem {
- public LocalEntranceProcessingItem(EntranceProcessor processor) {
- super(processor);
- }
-
- /**
- * If there are available events, first event in the queue will be
- * sent out on the output stream.
- * @return true if there is (at least) one available event and it was sent out
- * false otherwise
- */
- public boolean injectNextEvent() {
- if (this.getProcessor().hasNext()) {
- ContentEvent event = this.getProcessor().nextEvent();
- this.getOutputStream().put(event);
- return true;
- }
- return false;
- }
+ public LocalEntranceProcessingItem(EntranceProcessor processor) {
+ super(processor);
+ }
- /**
- * Start sending events by calling {@link #injectNextEvent()}. If there are no available events,
- * and that the stream is not entirely consumed, it will wait by calling
- * {@link #waitForNewEvents()} before attempting to send again.
- * </p>
- * When the stream is entirely consumed, the last event is tagged accordingly and the processor gets the
- * finished status.
- *
- */
- public void startSendingEvents () {
- if (this.getOutputStream() == null)
- throw new IllegalStateException("Try sending events from EntrancePI while outputStream is not set.");
-
- while (!this.getProcessor().isFinished()) {
- if (!this.injectNextEvent()) {
- try {
- waitForNewEvents();
- } catch (Exception e) {
- e.printStackTrace();
- break;
- }
- }
+ /**
+ * If there are available events, first event in the queue will be sent out on
+ * the output stream.
+ *
+ * @return true if there is (at least) one available event and it was sent out
+ * false otherwise
+ */
+ public boolean injectNextEvent() {
+ if (this.getProcessor().hasNext()) {
+ ContentEvent event = this.getProcessor().nextEvent();
+ this.getOutputStream().put(event);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Start sending events by calling {@link #injectNextEvent()}. If there are no
+ * available events, and the stream is not entirely consumed, it will
+ * wait by calling {@link #waitForNewEvents()} before attempting to send
+ * again. <p> When the stream is entirely consumed, the last event is tagged
+ * accordingly and the processor gets the finished status.
+ *
+ */
+ public void startSendingEvents() {
+ if (this.getOutputStream() == null)
+ throw new IllegalStateException("Try sending events from EntrancePI while outputStream is not set.");
+
+ while (!this.getProcessor().isFinished()) {
+ if (!this.injectNextEvent()) {
+ try {
+ waitForNewEvents();
+ } catch (Exception e) {
+ e.printStackTrace();
+ break;
}
- }
-
- /**
- * Method to wait for an amount of time when there are no available events.
- * Implementation of EntranceProcessingItem should override this method to
- * implement non-blocking wait or to adjust the amount of time.
- */
- protected void waitForNewEvents() throws Exception {
- Thread.sleep(100);
- }
+ }
+ }
+ }
+
+ /**
+ * Method to wait for an amount of time when there are no available events.
+ * Implementation of EntranceProcessingItem should override this method to
+ * implement non-blocking wait or to adjust the amount of time.
+ */
+ protected void waitForNewEvents() throws Exception {
+ Thread.sleep(100);
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ProcessingItem.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ProcessingItem.java
index 02fb84d..c1ecabc 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ProcessingItem.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/ProcessingItem.java
@@ -28,43 +28,42 @@
*/
public interface ProcessingItem extends IProcessingItem {
- /**
- * Connects this processing item in a round robin fashion. The events will
- * be distributed evenly between the instantiated processing items.
- *
- * @param inputStream
- * Stream to connect this processing item.
- * @return ProcessingItem
- */
- public ProcessingItem connectInputShuffleStream(Stream inputStream);
+ /**
+ * Connects this processing item in a round robin fashion. The events will be
+ * distributed evenly between the instantiated processing items.
+ *
+ * @param inputStream
+ * Stream to connect this processing item.
+ * @return ProcessingItem
+ */
+ public ProcessingItem connectInputShuffleStream(Stream inputStream);
- /**
- * Connects this processing item taking the event key into account. Events
- * will be routed to the processing item according to the modulus of its key
- * and the paralellism level. Ex.: key = 5 and paralellism = 2, 5 mod 2 = 1.
- * Processing item responsible for 1 will receive this event.
- *
- * @param inputStream
- * Stream to connect this processing item.
- * @return ProcessingItem
- */
- public ProcessingItem connectInputKeyStream(Stream inputStream);
+ /**
+ * Connects this processing item taking the event key into account. Events
+ * will be routed to the processing item according to the modulus of its key
+ * and the parallelism level. Ex.: key = 5 and parallelism = 2, 5 mod 2 = 1.
+ * Processing item responsible for 1 will receive this event.
+ *
+ * @param inputStream
+ * Stream to connect this processing item.
+ * @return ProcessingItem
+ */
+ public ProcessingItem connectInputKeyStream(Stream inputStream);
- /**
- * Connects this processing item to the stream in a broadcast fashion. All
- * processing items of this type will receive copy of the original event.
- *
- * @param inputStream
- * Stream to connect this processing item.
- * @return ProcessingItem
- */
- public ProcessingItem connectInputAllStream(Stream inputStream);
+ /**
+ * Connects this processing item to the stream in a broadcast fashion. All
+ * processing items of this type will receive copy of the original event.
+ *
+ * @param inputStream
+ * Stream to connect this processing item.
+ * @return ProcessingItem
+ */
+ public ProcessingItem connectInputAllStream(Stream inputStream);
-
- /**
- * Gets processing item parallelism level.
- *
- * @return int
- */
- public int getParallelism();
+ /**
+ * Gets processing item parallelism level.
+ *
+ * @return int
+ */
+ public int getParallelism();
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Stream.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Stream.java
index b496d35..0c54232 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Stream.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Stream.java
@@ -24,39 +24,38 @@
/**
* Stream interface.
- *
+ *
* @author severien
- *
+ *
*/
public interface Stream {
-
- /**
- * Puts events into a platform specific data stream.
- *
- * @param event
- */
- public void put(ContentEvent event);
-
- /**
- * Sets the stream id which is represented by a name.
- *
- * @param stream
- */
- //public void setStreamId(String stream);
-
-
- /**
- * Gets stream id.
- *
- * @return id
- */
- public String getStreamId();
-
- /**
- * Set batch size
- *
- * @param batchSize
- * the suggested size for batching messages on this stream
- */
- public void setBatchSize(int batchsize);
+
+ /**
+ * Puts events into a platform specific data stream.
+ *
+ * @param event
+ */
+ public void put(ContentEvent event);
+
+ /**
+ * Sets the stream id which is represented by a name.
+ *
+ * @param stream
+ */
+ // public void setStreamId(String stream);
+
+ /**
+ * Gets stream id.
+ *
+ * @return id
+ */
+ public String getStreamId();
+
+ /**
+ * Set batch size
+ *
+ * @param batchSize
+ * the suggested size for batching messages on this stream
+ */
+ public void setBatchSize(int batchsize);
}
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Topology.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Topology.java
index 6ad93ed..dce4974 100755
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Topology.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/Topology.java
@@ -21,65 +21,63 @@
*/
public interface Topology {
- /*
- * Name
- */
- /**
- * Get the topology's name
- *
- * @return the name of the topology
- */
- public String getTopologyName();
+ /*
+ * Name
+ */
+ /**
+ * Get the topology's name
+ *
+ * @return the name of the topology
+ */
+ public String getTopologyName();
- /**
- * Set the topology's name
- *
- * @param topologyName
- * the name of the topology
- */
- public void setTopologyName(String topologyName) ;
+ /**
+ * Set the topology's name
+ *
+ * @param topologyName
+ * the name of the topology
+ */
+ public void setTopologyName(String topologyName);
- /*
- * Entrance Processing Items
- */
- /**
- * Add an EntranceProcessingItem to this topology
- *
- * @param epi
- * the EntranceProcessingItem to be added
- */
- void addEntranceProcessingItem(EntranceProcessingItem epi);
-
-
- /*
- * Processing Items
- */
- /**
- * Add a ProcessingItem to this topology
- * with default parallelism level (i.e. 1)
- *
- * @param procItem
- * the ProcessingItem to be added
- */
- void addProcessingItem(IProcessingItem procItem);
-
- /**
- * Add a ProcessingItem to this topology
- * with an associated parallelism level
- *
- * @param procItem
- * the ProcessingItem to be added
- * @param parallelismHint
- * the parallelism level
- */
- void addProcessingItem(IProcessingItem procItem, int parallelismHint);
-
- /*
- * Streams
- */
- /**
- *
- * @param stream
- */
- void addStream(Stream stream);
+ /*
+ * Entrance Processing Items
+ */
+ /**
+ * Add an EntranceProcessingItem to this topology
+ *
+ * @param epi
+ * the EntranceProcessingItem to be added
+ */
+ void addEntranceProcessingItem(EntranceProcessingItem epi);
+
+ /*
+ * Processing Items
+ */
+ /**
+ * Add a ProcessingItem to this topology with default parallelism level (i.e.
+ * 1)
+ *
+ * @param procItem
+ * the ProcessingItem to be added
+ */
+ void addProcessingItem(IProcessingItem procItem);
+
+ /**
+ * Add a ProcessingItem to this topology with an associated parallelism level
+ *
+ * @param procItem
+ * the ProcessingItem to be added
+ * @param parallelismHint
+ * the parallelism level
+ */
+ void addProcessingItem(IProcessingItem procItem, int parallelismHint);
+
+ /*
+ * Streams
+ */
+ /**
+ *
+ * @param stream
+ */
+ void addStream(Stream stream);
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/TopologyBuilder.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/TopologyBuilder.java
index aebb136..90b50ea 100755
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/TopologyBuilder.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/topology/TopologyBuilder.java
@@ -32,192 +32,199 @@
*/
public class TopologyBuilder {
- // TODO:
- // Possible options:
- // 1. we may convert this as interface and platform dependent builder will inherit this method
- // 2. refactor by combining TopologyBuilder, ComponentFactory and Topology
- // -ve -> fat class where it has capabilities to instantiate specific component and connecting them
- // +ve -> easy abstraction for SAMOA developer "you just implement your builder logic here!"
- private ComponentFactory componentFactory;
- private Topology topology;
- private Map<Processor, IProcessingItem> mapProcessorToProcessingItem;
+ // TODO:
+ // Possible options:
+ // 1. we may convert this as interface and platform dependent builder will
+ // inherit this method
+ // 2. refactor by combining TopologyBuilder, ComponentFactory and Topology
+ // -ve -> fat class where it has capabilities to instantiate specific
+ // component and connecting them
+ // +ve -> easy abstraction for SAMOA developer
+ // "you just implement your builder logic here!"
+ private ComponentFactory componentFactory;
+ private Topology topology;
+ private Map<Processor, IProcessingItem> mapProcessorToProcessingItem;
- // TODO: refactor, temporary constructor used by Storm code
- public TopologyBuilder() {
- // TODO: initialize _componentFactory using dynamic binding
- // for now, use StormComponentFactory
- // should the factory be Singleton (?)
- // ans: at the moment, no, i.e. each builder will has its associated factory!
- // and the factory will be instantiated using dynamic binding
- // this.componentFactory = new StormComponentFactory();
- }
+ // TODO: refactor, temporary constructor used by Storm code
+ public TopologyBuilder() {
+ // TODO: initialize _componentFactory using dynamic binding
+ // for now, use StormComponentFactory
+ // should the factory be Singleton (?)
+ // ans: at the moment, no, i.e. each builder will have its associated
+ // factory!
+ // and the factory will be instantiated using dynamic binding
+ // this.componentFactory = new StormComponentFactory();
+ }
- // TODO: refactor, temporary constructor used by S4 code
- public TopologyBuilder(ComponentFactory theFactory) {
- this.componentFactory = theFactory;
- }
+ // TODO: refactor, temporary constructor used by S4 code
+ public TopologyBuilder(ComponentFactory theFactory) {
+ this.componentFactory = theFactory;
+ }
- /**
- * Initiates topology with a specific name.
- *
- * @param topologyName
- */
- public void initTopology(String topologyName) {
- this.initTopology(topologyName, 0);
- }
-
- /**
- * Initiates topology with a specific name and a delay between consecutive instances.
- *
- * @param topologyName
- * @param delay
- * delay between injections of two instances from source (in milliseconds)
- */
- public void initTopology(String topologyName, int delay) {
- if (this.topology != null) {
- // TODO: possible refactor this code later
- System.out.println("Topology has been initialized before!");
- return;
- }
- this.topology = componentFactory.createTopology(topologyName);
- }
+ /**
+ * Initiates topology with a specific name.
+ *
+ * @param topologyName
+ */
+ public void initTopology(String topologyName) {
+ this.initTopology(topologyName, 0);
+ }
- /**
- * Returns the platform specific topology.
- *
- * @return
- */
- public Topology build() {
- return topology;
+ /**
+ * Initiates topology with a specific name and a delay between consecutive
+ * instances.
+ *
+ * @param topologyName
+ * @param delay
+ * delay between injections of two instances from source (in
+ * milliseconds)
+ */
+ public void initTopology(String topologyName, int delay) {
+ if (this.topology != null) {
+ // TODO: possibly refactor this code later
+ System.out.println("Topology has been initialized before!");
+ return;
}
+ this.topology = componentFactory.createTopology(topologyName);
+ }
- public ProcessingItem addProcessor(Processor processor, int parallelism) {
- ProcessingItem pi = createPi(processor, parallelism);
- if (this.mapProcessorToProcessingItem == null)
- this.mapProcessorToProcessingItem = new HashMap<Processor, IProcessingItem>();
- this.mapProcessorToProcessingItem.put(processor, pi);
- return pi;
- }
+ /**
+ * Returns the platform specific topology.
+ *
+ * @return
+ */
+ public Topology build() {
+ return topology;
+ }
- public ProcessingItem addProcessor(Processor processor) {
- return addProcessor(processor, 1);
- }
+ public ProcessingItem addProcessor(Processor processor, int parallelism) {
+ ProcessingItem pi = createPi(processor, parallelism);
+ if (this.mapProcessorToProcessingItem == null)
+ this.mapProcessorToProcessingItem = new HashMap<Processor, IProcessingItem>();
+ this.mapProcessorToProcessingItem.put(processor, pi);
+ return pi;
+ }
- public ProcessingItem connectInputShuffleStream(Stream inputStream, Processor processor) {
- ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
- Preconditions.checkNotNull(pi, "Trying to connect to null PI");
- return pi.connectInputShuffleStream(inputStream);
- }
+ public ProcessingItem addProcessor(Processor processor) {
+ return addProcessor(processor, 1);
+ }
- public ProcessingItem connectInputKeyStream(Stream inputStream, Processor processor) {
- ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
- Preconditions.checkNotNull(pi, "Trying to connect to null PI");
- return pi.connectInputKeyStream(inputStream);
- }
+ public ProcessingItem connectInputShuffleStream(Stream inputStream, Processor processor) {
+ ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
+ Preconditions.checkNotNull(pi, "Trying to connect to null PI");
+ return pi.connectInputShuffleStream(inputStream);
+ }
- public ProcessingItem connectInputAllStream(Stream inputStream, Processor processor) {
- ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
- Preconditions.checkNotNull(pi, "Trying to connect to null PI");
- return pi.connectInputAllStream(inputStream);
- }
+ public ProcessingItem connectInputKeyStream(Stream inputStream, Processor processor) {
+ ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
+ Preconditions.checkNotNull(pi, "Trying to connect to null PI");
+ return pi.connectInputKeyStream(inputStream);
+ }
- public Stream createInputShuffleStream(Processor processor, Processor dest) {
- Stream inputStream = this.createStream(dest);
- ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
- Preconditions.checkNotNull(pi, "Trying to connect to null PI");
- pi.connectInputShuffleStream(inputStream);
- return inputStream;
- }
+ public ProcessingItem connectInputAllStream(Stream inputStream, Processor processor) {
+ ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
+ Preconditions.checkNotNull(pi, "Trying to connect to null PI");
+ return pi.connectInputAllStream(inputStream);
+ }
- public Stream createInputKeyStream(Processor processor, Processor dest) {
- Stream inputStream = this.createStream(dest);
- ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
- Preconditions.checkNotNull(pi, "Trying to connect to null PI");
- pi.connectInputKeyStream(inputStream);
- return inputStream;
- }
+ public Stream createInputShuffleStream(Processor processor, Processor dest) {
+ Stream inputStream = this.createStream(dest);
+ ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
+ Preconditions.checkNotNull(pi, "Trying to connect to null PI");
+ pi.connectInputShuffleStream(inputStream);
+ return inputStream;
+ }
- public Stream createInputAllStream(Processor processor, Processor dest) {
- Stream inputStream = this.createStream(dest);
- ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
- Preconditions.checkNotNull(pi, "Trying to connect to null PI");
- pi.connectInputAllStream(inputStream);
- return inputStream;
- }
+ public Stream createInputKeyStream(Processor processor, Processor dest) {
+ Stream inputStream = this.createStream(dest);
+ ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
+ Preconditions.checkNotNull(pi, "Trying to connect to null PI");
+ pi.connectInputKeyStream(inputStream);
+ return inputStream;
+ }
- public Stream createStream(Processor processor) {
- IProcessingItem pi = mapProcessorToProcessingItem.get(processor);
- Stream ret = null;
- Preconditions.checkNotNull(pi, "Trying to create stream from null PI");
- ret = this.createStream(pi);
- if (pi instanceof EntranceProcessingItem)
- ((EntranceProcessingItem) pi).setOutputStream(ret);
- return ret;
- }
+ public Stream createInputAllStream(Processor processor, Processor dest) {
+ Stream inputStream = this.createStream(dest);
+ ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
+ Preconditions.checkNotNull(pi, "Trying to connect to null PI");
+ pi.connectInputAllStream(inputStream);
+ return inputStream;
+ }
- public EntranceProcessingItem addEntranceProcessor(EntranceProcessor entranceProcessor) {
- EntranceProcessingItem pi = createEntrancePi(entranceProcessor);
- if (this.mapProcessorToProcessingItem == null)
- this.mapProcessorToProcessingItem = new HashMap<Processor, IProcessingItem>();
- mapProcessorToProcessingItem.put(entranceProcessor, pi);
- return pi;
- }
+ public Stream createStream(Processor processor) {
+ IProcessingItem pi = mapProcessorToProcessingItem.get(processor);
+ Stream ret = null;
+ Preconditions.checkNotNull(pi, "Trying to create stream from null PI");
+ ret = this.createStream(pi);
+ if (pi instanceof EntranceProcessingItem)
+ ((EntranceProcessingItem) pi).setOutputStream(ret);
+ return ret;
+ }
- public ProcessingItem getProcessingItem(Processor processor) {
- ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
- Preconditions.checkNotNull(pi, "Trying to retrieve null PI");
- return pi;
- }
+ public EntranceProcessingItem addEntranceProcessor(EntranceProcessor entranceProcessor) {
+ EntranceProcessingItem pi = createEntrancePi(entranceProcessor);
+ if (this.mapProcessorToProcessingItem == null)
+ this.mapProcessorToProcessingItem = new HashMap<Processor, IProcessingItem>();
+ mapProcessorToProcessingItem.put(entranceProcessor, pi);
+ return pi;
+ }
- /**
- * Creates a processing item with a specific processor and paralellism level of 1.
- *
- * @param processor
- * @return ProcessingItem
- */
- @SuppressWarnings("unused")
- private ProcessingItem createPi(Processor processor) {
- return createPi(processor, 1);
- }
+ public ProcessingItem getProcessingItem(Processor processor) {
+ ProcessingItem pi = (ProcessingItem) mapProcessorToProcessingItem.get(processor);
+ Preconditions.checkNotNull(pi, "Trying to retrieve null PI");
+ return pi;
+ }
- /**
- * Creates a processing item with a specific processor and paralellism level.
- *
- * @param processor
- * @param parallelism
- * @return ProcessingItem
- */
- private ProcessingItem createPi(Processor processor, int parallelism) {
- ProcessingItem pi = this.componentFactory.createPi(processor, parallelism);
- this.topology.addProcessingItem(pi, parallelism);
- return pi;
- }
+ /**
+ * Creates a processing item with a specific processor and parallelism level
+ * of 1.
+ *
+ * @param processor
+ * @return ProcessingItem
+ */
+ @SuppressWarnings("unused")
+ private ProcessingItem createPi(Processor processor) {
+ return createPi(processor, 1);
+ }
- /**
- * Creates a platform specific entrance processing item.
- *
- * @param processor
- * @return
- */
- private EntranceProcessingItem createEntrancePi(EntranceProcessor processor) {
- EntranceProcessingItem epi = this.componentFactory.createEntrancePi(processor);
- this.topology.addEntranceProcessingItem(epi);
- if (this.mapProcessorToProcessingItem == null)
- this.mapProcessorToProcessingItem = new HashMap<Processor, IProcessingItem>();
- this.mapProcessorToProcessingItem.put(processor, epi);
- return epi;
- }
+ /**
+ * Creates a processing item with a specific processor and parallelism level.
+ *
+ * @param processor
+ * @param parallelism
+ * @return ProcessingItem
+ */
+ private ProcessingItem createPi(Processor processor, int parallelism) {
+ ProcessingItem pi = this.componentFactory.createPi(processor, parallelism);
+ this.topology.addProcessingItem(pi, parallelism);
+ return pi;
+ }
- /**
- * Creates a platform specific stream.
- *
- * @param sourcePi
- * source processing item.
- * @return
- */
- private Stream createStream(IProcessingItem sourcePi) {
- Stream stream = this.componentFactory.createStream(sourcePi);
- this.topology.addStream(stream);
- return stream;
- }
+ /**
+ * Creates a platform specific entrance processing item.
+ *
+ * @param processor
+ * @return
+ */
+ private EntranceProcessingItem createEntrancePi(EntranceProcessor processor) {
+ EntranceProcessingItem epi = this.componentFactory.createEntrancePi(processor);
+ this.topology.addEntranceProcessingItem(epi);
+ if (this.mapProcessorToProcessingItem == null)
+ this.mapProcessorToProcessingItem = new HashMap<Processor, IProcessingItem>();
+ this.mapProcessorToProcessingItem.put(processor, epi);
+ return epi;
+ }
+
+ /**
+ * Creates a platform specific stream.
+ *
+ * @param sourcePi
+ * source processing item.
+ * @return
+ */
+ private Stream createStream(IProcessingItem sourcePi) {
+ Stream stream = this.componentFactory.createStream(sourcePi);
+ this.topology.addStream(stream);
+ return stream;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/PartitioningScheme.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/PartitioningScheme.java
index ac6fc3f..457e407 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/PartitioningScheme.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/PartitioningScheme.java
@@ -22,11 +22,12 @@
/**
* Represents the 3 schemes to partition the streams
+ *
* @author Anh Thu Vu
- *
+ *
*/
public enum PartitioningScheme {
- SHUFFLE, GROUP_BY_KEY, BROADCAST
+ SHUFFLE, GROUP_BY_KEY, BROADCAST
}
// TODO: use this enum in S4
// Storm doesn't seem to need this
\ No newline at end of file
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/StreamDestination.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/StreamDestination.java
index 2781fb8..a759b26 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/StreamDestination.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/StreamDestination.java
@@ -23,43 +23,42 @@
import com.yahoo.labs.samoa.topology.IProcessingItem;
/**
- * Represents one destination for streams. It has the info of:
- * the ProcessingItem, parallelismHint, and partitioning scheme.
- * Usage:
- * - When ProcessingItem connects to a stream, it will pass
- * a StreamDestination to the stream.
- * - Stream manages a set of StreamDestination.
- * - Used in single-threaded and multi-threaded local mode.
+ * Represents one destination for streams. It has the info of: the
+ * ProcessingItem, parallelismHint, and partitioning scheme. Usage: - When
+ * ProcessingItem connects to a stream, it will pass a StreamDestination to the
+ * stream. - Stream manages a set of StreamDestination. - Used in
+ * single-threaded and multi-threaded local mode.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class StreamDestination {
- private IProcessingItem pi;
- private int parallelism;
- private PartitioningScheme type;
-
- /*
- * Constructor
- */
- public StreamDestination(IProcessingItem pi, int parallelismHint, PartitioningScheme type) {
- this.pi = pi;
- this.parallelism = parallelismHint;
- this.type = type;
- }
-
- /*
- * Getters
- */
- public IProcessingItem getProcessingItem() {
- return this.pi;
- }
-
- public int getParallelism() {
- return this.parallelism;
- }
-
- public PartitioningScheme getPartitioningScheme() {
- return this.type;
- }
+ private IProcessingItem pi;
+ private int parallelism;
+ private PartitioningScheme type;
+
+ /*
+ * Constructor
+ */
+ public StreamDestination(IProcessingItem pi, int parallelismHint, PartitioningScheme type) {
+ this.pi = pi;
+ this.parallelism = parallelismHint;
+ this.type = type;
+ }
+
+ /*
+ * Getters
+ */
+ public IProcessingItem getProcessingItem() {
+ return this.pi;
+ }
+
+ public int getParallelism() {
+ return this.parallelism;
+ }
+
+ public PartitioningScheme getPartitioningScheme() {
+ return this.type;
+ }
}
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/Utils.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/Utils.java
index ce18d78..bd819ad 100644
--- a/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/Utils.java
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/utils/Utils.java
@@ -35,148 +35,150 @@
/**
* Utils class for building and deploying applications programmatically.
+ *
* @author severien
- *
+ *
*/
public class Utils {
- public static void buildSamoaPackage() {
- try {
- String output = "/tmp/samoa/samoa.jar";// System.getProperty("user.home") + "/samoa.jar";
- Manifest manifest = createManifest();
+ public static void buildSamoaPackage() {
+ try {
+ String output = "/tmp/samoa/samoa.jar";// System.getProperty("user.home")
+ // + "/samoa.jar";
+ Manifest manifest = createManifest();
- BufferedOutputStream bo;
+ BufferedOutputStream bo;
- bo = new BufferedOutputStream(new FileOutputStream(output));
- JarOutputStream jo = new JarOutputStream(bo, manifest);
+ bo = new BufferedOutputStream(new FileOutputStream(output));
+ JarOutputStream jo = new JarOutputStream(bo, manifest);
- String baseDir = System.getProperty("user.dir");
- System.out.println(baseDir);
-
- File samoaJar = new File(baseDir+"/target/samoa-0.0.1-SNAPSHOT.jar");
- addEntry(jo,samoaJar,baseDir+"/target/","/app/");
- addLibraries(jo);
-
- jo.close();
- bo.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
+ String baseDir = System.getProperty("user.dir");
+ System.out.println(baseDir);
- }
-
- // TODO should get the modules file from the parameters
- public static void buildModulesPackage(List<String> modulesNames) {
- System.out.println(System.getProperty("user.dir"));
- try {
- String baseDir = System.getProperty("user.dir");
- List<File> filesArray = new ArrayList<>();
- for (String module : modulesNames) {
- module = "/"+module.replace(".", "/")+".class";
- filesArray.add(new File(baseDir+module));
- }
- String output = System.getProperty("user.home") + "/modules.jar";
+ File samoaJar = new File(baseDir + "/target/samoa-0.0.1-SNAPSHOT.jar");
+ addEntry(jo, samoaJar, baseDir + "/target/", "/app/");
+ addLibraries(jo);
- Manifest manifest = new Manifest();
- manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION,
- "1.0");
- manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_URL,
- "http://samoa.yahoo.com");
- manifest.getMainAttributes().put(
- Attributes.Name.IMPLEMENTATION_VERSION, "0.1");
- manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VENDOR,
- "Yahoo");
- manifest.getMainAttributes().put(
- Attributes.Name.IMPLEMENTATION_VENDOR_ID, "SAMOA");
+ jo.close();
+ bo.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
- BufferedOutputStream bo;
+ }
- bo = new BufferedOutputStream(new FileOutputStream(output));
- JarOutputStream jo = new JarOutputStream(bo, manifest);
+ // TODO should get the modules file from the parameters
+ public static void buildModulesPackage(List<String> modulesNames) {
+ System.out.println(System.getProperty("user.dir"));
+ try {
+ String baseDir = System.getProperty("user.dir");
+ List<File> filesArray = new ArrayList<>();
+ for (String module : modulesNames) {
+ module = "/" + module.replace(".", "/") + ".class";
+ filesArray.add(new File(baseDir + module));
+ }
+ String output = System.getProperty("user.home") + "/modules.jar";
- File[] files = filesArray.toArray(new File[filesArray.size()]);
- addEntries(jo,files, baseDir, "");
+ Manifest manifest = new Manifest();
+ manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION,
+ "1.0");
+ manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_URL,
+ "http://samoa.yahoo.com");
+ manifest.getMainAttributes().put(
+ Attributes.Name.IMPLEMENTATION_VERSION, "0.1");
+ manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VENDOR,
+ "Yahoo");
+ manifest.getMainAttributes().put(
+ Attributes.Name.IMPLEMENTATION_VENDOR_ID, "SAMOA");
- jo.close();
- bo.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
+ BufferedOutputStream bo;
- }
+ bo = new BufferedOutputStream(new FileOutputStream(output));
+ JarOutputStream jo = new JarOutputStream(bo, manifest);
- private static void addLibraries(JarOutputStream jo) {
- try {
- String baseDir = System.getProperty("user.dir");
- String libDir = baseDir+"/target/lib";
- File inputFile = new File(libDir);
-
- File[] files = inputFile.listFiles();
- for (File file : files) {
- addEntry(jo, file, baseDir, "lib");
- }
- jo.close();
-
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
-
- private static void addEntries(JarOutputStream jo, File[] files, String baseDir, String rootDir){
- for (File file : files) {
+ File[] files = filesArray.toArray(new File[filesArray.size()]);
+ addEntries(jo, files, baseDir, "");
- if (!file.isDirectory()) {
- addEntry(jo, file, baseDir, rootDir);
- } else {
- File dir = new File(file.getAbsolutePath());
- addEntries(jo, dir.listFiles(), baseDir, rootDir);
- }
- }
- }
-
- private static void addEntry(JarOutputStream jo, File file, String baseDir, String rootDir) {
- try {
- BufferedInputStream bi = new BufferedInputStream(new FileInputStream(file));
+ jo.close();
+ bo.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
- String path = file.getAbsolutePath().replaceFirst(baseDir, rootDir);
- jo.putNextEntry(new ZipEntry(path));
+ }
- byte[] buf = new byte[1024];
- int anz;
- while ((anz = bi.read(buf)) != -1) {
- jo.write(buf, 0, anz);
- }
- bi.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
+ private static void addLibraries(JarOutputStream jo) {
+ try {
+ String baseDir = System.getProperty("user.dir");
+ String libDir = baseDir + "/target/lib";
+ File inputFile = new File(libDir);
- public static Manifest createManifest() {
- Manifest manifest = new Manifest();
- manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION, "1.0");
- manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_URL, "http://samoa.yahoo.com");
- manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VERSION, "0.1");
- manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VENDOR, "Yahoo");
- manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VENDOR_ID, "SAMOA");
- Attributes s4Attributes = new Attributes();
- s4Attributes.putValue("S4-App-Class", "path.to.Class");
- Attributes.Name name = new Attributes.Name("S4-App-Class");
- Attributes.Name S4Version = new Attributes.Name("S4-Version");
- manifest.getMainAttributes().put(name, "samoa.topology.impl.DoTaskApp");
- manifest.getMainAttributes().put(S4Version, "0.6.0-incubating");
- return manifest;
- }
-
- public static Object getInstance(String className) {
+ File[] files = inputFile.listFiles();
+ for (File file : files) {
+ addEntry(jo, file, baseDir, "lib");
+ }
+ jo.close();
+
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ private static void addEntries(JarOutputStream jo, File[] files, String baseDir, String rootDir) {
+ for (File file : files) {
+
+ if (!file.isDirectory()) {
+ addEntry(jo, file, baseDir, rootDir);
+ } else {
+ File dir = new File(file.getAbsolutePath());
+ addEntries(jo, dir.listFiles(), baseDir, rootDir);
+ }
+ }
+ }
+
+ private static void addEntry(JarOutputStream jo, File file, String baseDir, String rootDir) {
+ try {
+ BufferedInputStream bi = new BufferedInputStream(new FileInputStream(file));
+
+ String path = file.getAbsolutePath().replaceFirst(baseDir, rootDir);
+ jo.putNextEntry(new ZipEntry(path));
+
+ byte[] buf = new byte[1024];
+ int anz;
+ while ((anz = bi.read(buf)) != -1) {
+ jo.write(buf, 0, anz);
+ }
+ bi.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ public static Manifest createManifest() {
+ Manifest manifest = new Manifest();
+ manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION, "1.0");
+ manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_URL, "http://samoa.yahoo.com");
+ manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VERSION, "0.1");
+ manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VENDOR, "Yahoo");
+ manifest.getMainAttributes().put(Attributes.Name.IMPLEMENTATION_VENDOR_ID, "SAMOA");
+ Attributes s4Attributes = new Attributes();
+ s4Attributes.putValue("S4-App-Class", "path.to.Class");
+ Attributes.Name name = new Attributes.Name("S4-App-Class");
+ Attributes.Name S4Version = new Attributes.Name("S4-Version");
+ manifest.getMainAttributes().put(name, "samoa.topology.impl.DoTaskApp");
+ manifest.getMainAttributes().put(S4Version, "0.6.0-incubating");
+ return manifest;
+ }
+
+ public static Object getInstance(String className) {
Class<?> cls;
- Object obj = null;
- try {
- cls = Class.forName(className);
- obj = cls.newInstance();
- } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
- e.printStackTrace();
- }
- return obj;
- }
+ Object obj = null;
+ try {
+ cls = Class.forName(className);
+ obj = cls.newInstance();
+ } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
+ e.printStackTrace();
+ }
+ return obj;
+ }
}
diff --git a/samoa-api/src/test/java/com/yahoo/labs/samoa/core/DoubleVectorTest.java b/samoa-api/src/test/java/com/yahoo/labs/samoa/core/DoubleVectorTest.java
index f82588b..e8d589f 100644
--- a/samoa-api/src/test/java/com/yahoo/labs/samoa/core/DoubleVectorTest.java
+++ b/samoa-api/src/test/java/com/yahoo/labs/samoa/core/DoubleVectorTest.java
@@ -27,71 +27,78 @@
import org.junit.Test;
public class DoubleVectorTest {
- private DoubleVector emptyVector, array5Vector;
+ private DoubleVector emptyVector, array5Vector;
- @Before
- public void setUp() {
- emptyVector = new DoubleVector();
- array5Vector = new DoubleVector(new double[] { 1.1, 2.5, 0, 4.7, 0 });
- }
+ @Before
+ public void setUp() {
+ emptyVector = new DoubleVector();
+ array5Vector = new DoubleVector(new double[] { 1.1, 2.5, 0, 4.7, 0 });
+ }
- @Test
- public void testGetArrayRef() {
- assertThat(emptyVector.getArrayRef(), notNullValue());
- assertTrue(emptyVector.getArrayRef() == emptyVector.getArrayRef());
- assertEquals(5, array5Vector.getArrayRef().length);
- }
+ @Test
+ public void testGetArrayRef() {
+ assertThat(emptyVector.getArrayRef(), notNullValue());
+ assertTrue(emptyVector.getArrayRef() == emptyVector.getArrayRef());
+ assertEquals(5, array5Vector.getArrayRef().length);
+ }
- @Test
- public void testGetArrayCopy() {
- double[] arrayRef;
- arrayRef = emptyVector.getArrayRef();
- assertTrue(arrayRef != emptyVector.getArrayCopy());
- assertThat(arrayRef, is(equalTo(emptyVector.getArrayCopy())));
+ @Test
+ public void testGetArrayCopy() {
+ double[] arrayRef;
+ arrayRef = emptyVector.getArrayRef();
+ assertTrue(arrayRef != emptyVector.getArrayCopy());
+ assertThat(arrayRef, is(equalTo(emptyVector.getArrayCopy())));
- arrayRef = array5Vector.getArrayRef();
- assertTrue(arrayRef != array5Vector.getArrayCopy());
- assertThat(arrayRef, is(equalTo(array5Vector.getArrayCopy())));
- }
+ arrayRef = array5Vector.getArrayRef();
+ assertTrue(arrayRef != array5Vector.getArrayCopy());
+ assertThat(arrayRef, is(equalTo(array5Vector.getArrayCopy())));
+ }
- @Test
- public void testNumNonZeroEntries() {
- assertEquals(0, emptyVector.numNonZeroEntries());
- assertEquals(3, array5Vector.numNonZeroEntries());
- }
+ @Test
+ public void testNumNonZeroEntries() {
+ assertEquals(0, emptyVector.numNonZeroEntries());
+ assertEquals(3, array5Vector.numNonZeroEntries());
+ }
- @Test(expected = IndexOutOfBoundsException.class)
- public void testGetValueOutOfBound() {
- @SuppressWarnings("unused")
- double value = emptyVector.getArrayRef()[0];
- }
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testGetValueOutOfBound() {
+ @SuppressWarnings("unused")
+ double value = emptyVector.getArrayRef()[0];
+ }
- @Test()
- public void testSetValue() {
- // test automatic vector enlargement
- emptyVector.setValue(0, 1.0);
- assertEquals(1, emptyVector.getArrayRef().length);
- assertEquals(1.0, emptyVector.getArrayRef()[0], 0.0); // should be exactly the same, so delta=0.0
+ @Test()
+ public void testSetValue() {
+ // test automatic vector enlargement
+ emptyVector.setValue(0, 1.0);
+ assertEquals(1, emptyVector.getArrayRef().length);
+ assertEquals(1.0, emptyVector.getArrayRef()[0], 0.0); // should be exactly
+ // the same, so
+ // delta=0.0
- emptyVector.setValue(5, 5.5);
- assertEquals(6, emptyVector.getArrayRef().length);
- assertEquals(2, emptyVector.numNonZeroEntries());
- assertEquals(5.5, emptyVector.getArrayRef()[5], 0.0); // should be exactly the same, so delta=0.0
- }
+ emptyVector.setValue(5, 5.5);
+ assertEquals(6, emptyVector.getArrayRef().length);
+ assertEquals(2, emptyVector.numNonZeroEntries());
+ assertEquals(5.5, emptyVector.getArrayRef()[5], 0.0); // should be exactly
+ // the same, so
+ // delta=0.0
+ }
- @Test
- public void testAddToValue() {
- array5Vector.addToValue(2, 5.0);
- assertEquals(5, array5Vector.getArrayRef()[2], 0.0); // should be exactly the same, so delta=0.0
+ @Test
+ public void testAddToValue() {
+ array5Vector.addToValue(2, 5.0);
+ assertEquals(5, array5Vector.getArrayRef()[2], 0.0); // should be exactly
+ // the same, so
+ // delta=0.0
- // test automatic vector enlargement
- emptyVector.addToValue(0, 1.0);
- assertEquals(1, emptyVector.getArrayRef()[0], 0.0); // should be exactly the same, so delta=0.0
- }
+ // test automatic vector enlargement
+ emptyVector.addToValue(0, 1.0);
+ assertEquals(1, emptyVector.getArrayRef()[0], 0.0); // should be exactly the
+ // same, so delta=0.0
+ }
- @Test
- public void testSumOfValues() {
- assertEquals(1.1 + 2.5 + 4.7, array5Vector.sumOfValues(), Double.MIN_NORMAL);
- }
+ @Test
+ public void testSumOfValues() {
+ assertEquals(1.1 + 2.5 + 4.7, array5Vector.sumOfValues(), Double.MIN_NORMAL);
+ }
}
diff --git a/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSourceTest.java b/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSourceTest.java
index 04f3184..51ec57d 100644
--- a/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSourceTest.java
+++ b/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/HDFSFileStreamSourceTest.java
@@ -47,257 +47,260 @@
import org.junit.Test;
public class HDFSFileStreamSourceTest {
-
- private static final String[] HOSTS = {"localhost"};
- private static final String BASE_DIR = "/minidfsTest";
- private static final int NUM_FILES_IN_DIR = 4;
- private static final int NUM_NOISE_FILES_IN_DIR = 2;
-
- private HDFSFileStreamSource streamSource;
- private Configuration config;
- private MiniDFSCluster hdfsCluster;
- private String hdfsURI;
+ private static final String[] HOSTS = { "localhost" };
+ private static final String BASE_DIR = "/minidfsTest";
+ private static final int NUM_FILES_IN_DIR = 4;
+ private static final int NUM_NOISE_FILES_IN_DIR = 2;
- @Before
- public void setUp() throws Exception {
- // Start MiniDFSCluster
- MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(new Configuration()).hosts(HOSTS).numDataNodes(1).format(true);
- hdfsCluster = builder.build();
- hdfsCluster.waitActive();
- hdfsURI = "hdfs://localhost:"+ hdfsCluster.getNameNodePort();
-
- // Construct stream source
- streamSource = new HDFSFileStreamSource();
-
- // General config
- config = new Configuration();
- config.set("fs.defaultFS",hdfsURI);
- }
+ private HDFSFileStreamSource streamSource;
- @After
- public void tearDown() throws Exception {
- hdfsCluster.shutdown();
- }
+ private Configuration config;
+ private MiniDFSCluster hdfsCluster;
+ private String hdfsURI;
- /*
- * Init tests
- */
- @Test
- public void testInitWithSingleFileAndExtension() {
- // write input file
- writeSimpleFiles(BASE_DIR,"txt",1);
-
- // init with path to input file
- streamSource.init(config, BASE_DIR+"/1.txt", "txt");
-
- //assertions
- assertEquals("Size of filePaths is not correct.", 1,streamSource.getFilePathListSize(),0);
- String fn = streamSource.getFilePathAt(0);
- assertTrue("Incorrect file in filePaths.",fn.equals(BASE_DIR+"/1.txt") || fn.equals(hdfsURI+BASE_DIR+"1.txt"));
- }
-
- @Test
- public void testInitWithSingleFileAndNullExtension() {
- // write input file
- writeSimpleFiles(BASE_DIR,"txt",1);
-
- // init with path to input file
- streamSource.init(config, BASE_DIR+"/1.txt", null);
-
- // assertions
- assertEquals("Size of filePaths is not correct.", 1,streamSource.getFilePathListSize(),0);
- String fn = streamSource.getFilePathAt(0);
- assertTrue("Incorrect file in filePaths.",fn.equals(BASE_DIR+"/1.txt") || fn.equals(hdfsURI+BASE_DIR+"1.txt"));
- }
-
- @Test
- public void testInitWithFolderAndExtension() {
- // write input files & noise files
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
- writeSimpleFiles(BASE_DIR,null,NUM_NOISE_FILES_IN_DIR);
-
- // init with path to input dir
- streamSource.init(config, BASE_DIR, "txt");
-
- // assertions
- assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR,streamSource.getFilePathListSize(),0);
- Set<String> filenames = new HashSet<String>();
- for (int i=1; i<=NUM_FILES_IN_DIR; i++) {
- String targetFn = BASE_DIR+"/"+Integer.toString(i)+".txt";
- filenames.add(targetFn);
- filenames.add(hdfsURI+targetFn);
- }
- for (int i=0; i<NUM_FILES_IN_DIR; i++) {
- String fn = streamSource.getFilePathAt(i);
- assertTrue("Incorrect file in filePaths:"+fn,filenames.contains(fn));
- }
- }
-
- @Test
- public void testInitWithFolderAndNullExtension() {
- // write input file
- writeSimpleFiles(BASE_DIR,null,NUM_FILES_IN_DIR);
-
- // init with path to input dir
- streamSource.init(config, BASE_DIR, null);
-
- // assertions
- assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR,streamSource.getFilePathListSize(),0);
- Set<String> filenames = new HashSet<String>();
- for (int i=1; i<=NUM_FILES_IN_DIR; i++) {
- String targetFn = BASE_DIR+"/"+Integer.toString(i);
- filenames.add(targetFn);
- filenames.add(hdfsURI+targetFn);
- }
- for (int i=0; i< NUM_FILES_IN_DIR; i++) {
- String fn = streamSource.getFilePathAt(i);
- assertTrue("Incorrect file in filePaths:"+fn,filenames.contains(fn));
- }
- }
-
- /*
- * getNextInputStream tests
- */
- @Test
- public void testGetNextInputStream() {
- // write input files & noise files
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
-
- // init with path to input dir
- streamSource.init(config, BASE_DIR, "txt");
-
- // call getNextInputStream & assertions
- Set<String> contents = new HashSet<String>();
- for (int i=1; i<=NUM_FILES_IN_DIR; i++) {
- contents.add(Integer.toString(i));
- }
- for (int i=0; i< NUM_FILES_IN_DIR; i++) {
- InputStream inStream = streamSource.getNextInputStream();
- assertNotNull("Unexpected end of input stream list.",inStream);
-
- BufferedReader rd = new BufferedReader(new InputStreamReader(inStream));
- String inputRead = null;
- try {
- inputRead = rd.readLine();
- } catch (IOException ioe) {
- fail("Fail reading from stream at index:"+i + ioe.getMessage());
- }
- assertTrue("File content is incorrect.",contents.contains(inputRead));
- Iterator<String> it = contents.iterator();
- while (it.hasNext()) {
- if (it.next().equals(inputRead)) {
- it.remove();
- break;
- }
- }
- }
-
- // assert that another call to getNextInputStream will return null
- assertNull("Call getNextInputStream after the last file did not return null.",streamSource.getNextInputStream());
- }
-
- /*
- * getCurrentInputStream tests
- */
- public void testGetCurrentInputStream() {
- // write input files & noise files
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
-
- // init with path to input dir
- streamSource.init(config, BASE_DIR, "txt");
-
- // call getNextInputStream, getCurrentInputStream & assertions
- for (int i=0; i<= NUM_FILES_IN_DIR; i++) { // test also after-end-of-list
- InputStream inStream1 = streamSource.getNextInputStream();
- InputStream inStream2 = streamSource.getCurrentInputStream();
- assertSame("Incorrect current input stream.",inStream1, inStream2);
- }
- }
-
- /*
- * reset tests
- */
- public void testReset() {
- // write input files & noise files
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
-
- // init with path to input dir
- streamSource.init(config, BASE_DIR, "txt");
-
- // Get the first input string
- InputStream firstInStream = streamSource.getNextInputStream();
- String firstInput = null;
- assertNotNull("Unexpected end of input stream list.",firstInStream);
-
- BufferedReader rd1 = new BufferedReader(new InputStreamReader(firstInStream));
- try {
- firstInput = rd1.readLine();
- } catch (IOException ioe) {
- fail("Fail reading from stream at index:0" + ioe.getMessage());
- }
-
- // call getNextInputStream a few times
- streamSource.getNextInputStream();
-
- // call reset, call next, assert that output is 1 (the first file)
- try {
- streamSource.reset();
- } catch (IOException ioe) {
- fail("Fail resetting stream source." + ioe.getMessage());
- }
-
- InputStream inStream = streamSource.getNextInputStream();
- assertNotNull("Unexpected end of input stream list.",inStream);
-
- BufferedReader rd2 = new BufferedReader(new InputStreamReader(inStream));
- String inputRead = null;
- try {
- inputRead = rd2.readLine();
- } catch (IOException ioe) {
- fail("Fail reading from stream at index:0" + ioe.getMessage());
- }
- assertEquals("File content is incorrect.",firstInput,inputRead);
- }
-
- private void writeSimpleFiles(String path, String ext, int numOfFiles) {
- // get filesystem
- FileSystem dfs;
- try {
- dfs = hdfsCluster.getFileSystem();
- } catch (IOException ioe) {
- fail("Could not access MiniDFSCluster" + ioe.getMessage());
- return;
- }
-
- // create basedir
- Path basedir = new Path(path);
- try {
- dfs.mkdirs(basedir);
- } catch (IOException ioe) {
- fail("Could not create DIR:"+ path + "\n" + ioe.getMessage());
- return;
- }
-
- // write files
- for (int i=1; i<=numOfFiles; i++) {
- String fn = null;
- if (ext != null) {
- fn = Integer.toString(i) + "."+ ext;
- } else {
- fn = Integer.toString(i);
- }
-
- try {
- OutputStream fin = dfs.create(new Path(path,fn));
- BufferedWriter wr = new BufferedWriter(new OutputStreamWriter(fin));
- wr.write(Integer.toString(i));
- wr.close();
- fin.close();
- } catch (IOException ioe) {
- fail("Fail writing to input file: "+ fn + " in directory: " + path + ioe.getMessage());
- }
- }
- }
+ @Before
+ public void setUp() throws Exception {
+ // Start MiniDFSCluster
+ MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(new Configuration()).hosts(HOSTS).numDataNodes(1)
+ .format(true);
+ hdfsCluster = builder.build();
+ hdfsCluster.waitActive();
+ hdfsURI = "hdfs://localhost:" + hdfsCluster.getNameNodePort();
+
+ // Construct stream source
+ streamSource = new HDFSFileStreamSource();
+
+ // General config
+ config = new Configuration();
+ config.set("fs.defaultFS", hdfsURI);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ hdfsCluster.shutdown();
+ }
+
+ /*
+ * Init tests
+ */
+ @Test
+ public void testInitWithSingleFileAndExtension() {
+ // write input file
+ writeSimpleFiles(BASE_DIR, "txt", 1);
+
+ // init with path to input file
+ streamSource.init(config, BASE_DIR + "/1.txt", "txt");
+
+ // assertions
+ assertEquals("Size of filePaths is not correct.", 1, streamSource.getFilePathListSize(), 0);
+ String fn = streamSource.getFilePathAt(0);
+ assertTrue("Incorrect file in filePaths.",
+ fn.equals(BASE_DIR + "/1.txt") || fn.equals(hdfsURI + BASE_DIR + "1.txt"));
+ }
+
+ @Test
+ public void testInitWithSingleFileAndNullExtension() {
+ // write input file
+ writeSimpleFiles(BASE_DIR, "txt", 1);
+
+ // init with path to input file
+ streamSource.init(config, BASE_DIR + "/1.txt", null);
+
+ // assertions
+ assertEquals("Size of filePaths is not correct.", 1, streamSource.getFilePathListSize(), 0);
+ String fn = streamSource.getFilePathAt(0);
+ assertTrue("Incorrect file in filePaths.",
+ fn.equals(BASE_DIR + "/1.txt") || fn.equals(hdfsURI + BASE_DIR + "1.txt"));
+ }
+
+ @Test
+ public void testInitWithFolderAndExtension() {
+ // write input files & noise files
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+ writeSimpleFiles(BASE_DIR, null, NUM_NOISE_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(config, BASE_DIR, "txt");
+
+ // assertions
+ assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR, streamSource.getFilePathListSize(), 0);
+ Set<String> filenames = new HashSet<String>();
+ for (int i = 1; i <= NUM_FILES_IN_DIR; i++) {
+ String targetFn = BASE_DIR + "/" + Integer.toString(i) + ".txt";
+ filenames.add(targetFn);
+ filenames.add(hdfsURI + targetFn);
+ }
+ for (int i = 0; i < NUM_FILES_IN_DIR; i++) {
+ String fn = streamSource.getFilePathAt(i);
+ assertTrue("Incorrect file in filePaths:" + fn, filenames.contains(fn));
+ }
+ }
+
+ @Test
+ public void testInitWithFolderAndNullExtension() {
+ // write input file
+ writeSimpleFiles(BASE_DIR, null, NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(config, BASE_DIR, null);
+
+ // assertions
+ assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR, streamSource.getFilePathListSize(), 0);
+ Set<String> filenames = new HashSet<String>();
+ for (int i = 1; i <= NUM_FILES_IN_DIR; i++) {
+ String targetFn = BASE_DIR + "/" + Integer.toString(i);
+ filenames.add(targetFn);
+ filenames.add(hdfsURI + targetFn);
+ }
+ for (int i = 0; i < NUM_FILES_IN_DIR; i++) {
+ String fn = streamSource.getFilePathAt(i);
+ assertTrue("Incorrect file in filePaths:" + fn, filenames.contains(fn));
+ }
+ }
+
+ /*
+ * getNextInputStream tests
+ */
+ @Test
+ public void testGetNextInputStream() {
+ // write input files & noise files
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(config, BASE_DIR, "txt");
+
+ // call getNextInputStream & assertions
+ Set<String> contents = new HashSet<String>();
+ for (int i = 1; i <= NUM_FILES_IN_DIR; i++) {
+ contents.add(Integer.toString(i));
+ }
+ for (int i = 0; i < NUM_FILES_IN_DIR; i++) {
+ InputStream inStream = streamSource.getNextInputStream();
+ assertNotNull("Unexpected end of input stream list.", inStream);
+
+ BufferedReader rd = new BufferedReader(new InputStreamReader(inStream));
+ String inputRead = null;
+ try {
+ inputRead = rd.readLine();
+ } catch (IOException ioe) {
+ fail("Fail reading from stream at index:" + i + ioe.getMessage());
+ }
+ assertTrue("File content is incorrect.", contents.contains(inputRead));
+ Iterator<String> it = contents.iterator();
+ while (it.hasNext()) {
+ if (it.next().equals(inputRead)) {
+ it.remove();
+ break;
+ }
+ }
+ }
+
+ // assert that another call to getNextInputStream will return null
+ assertNull("Call getNextInputStream after the last file did not return null.", streamSource.getNextInputStream());
+ }
+
+ /*
+ * getCurrentInputStream tests
+ */
+ public void testGetCurrentInputStream() {
+ // write input files & noise files
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(config, BASE_DIR, "txt");
+
+ // call getNextInputStream, getCurrentInputStream & assertions
+ for (int i = 0; i <= NUM_FILES_IN_DIR; i++) { // test also after-end-of-list
+ InputStream inStream1 = streamSource.getNextInputStream();
+ InputStream inStream2 = streamSource.getCurrentInputStream();
+ assertSame("Incorrect current input stream.", inStream1, inStream2);
+ }
+ }
+
+ /*
+ * reset tests
+ */
+ public void testReset() {
+ // write input files & noise files
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(config, BASE_DIR, "txt");
+
+ // Get the first input string
+ InputStream firstInStream = streamSource.getNextInputStream();
+ String firstInput = null;
+ assertNotNull("Unexpected end of input stream list.", firstInStream);
+
+ BufferedReader rd1 = new BufferedReader(new InputStreamReader(firstInStream));
+ try {
+ firstInput = rd1.readLine();
+ } catch (IOException ioe) {
+ fail("Fail reading from stream at index:0" + ioe.getMessage());
+ }
+
+ // call getNextInputStream a few times
+ streamSource.getNextInputStream();
+
+ // call reset, call next, assert that output is 1 (the first file)
+ try {
+ streamSource.reset();
+ } catch (IOException ioe) {
+ fail("Fail resetting stream source." + ioe.getMessage());
+ }
+
+ InputStream inStream = streamSource.getNextInputStream();
+ assertNotNull("Unexpected end of input stream list.", inStream);
+
+ BufferedReader rd2 = new BufferedReader(new InputStreamReader(inStream));
+ String inputRead = null;
+ try {
+ inputRead = rd2.readLine();
+ } catch (IOException ioe) {
+ fail("Fail reading from stream at index:0" + ioe.getMessage());
+ }
+ assertEquals("File content is incorrect.", firstInput, inputRead);
+ }
+
+ private void writeSimpleFiles(String path, String ext, int numOfFiles) {
+ // get filesystem
+ FileSystem dfs;
+ try {
+ dfs = hdfsCluster.getFileSystem();
+ } catch (IOException ioe) {
+ fail("Could not access MiniDFSCluster" + ioe.getMessage());
+ return;
+ }
+
+ // create basedir
+ Path basedir = new Path(path);
+ try {
+ dfs.mkdirs(basedir);
+ } catch (IOException ioe) {
+ fail("Could not create DIR:" + path + "\n" + ioe.getMessage());
+ return;
+ }
+
+ // write files
+ for (int i = 1; i <= numOfFiles; i++) {
+ String fn = null;
+ if (ext != null) {
+ fn = Integer.toString(i) + "." + ext;
+ } else {
+ fn = Integer.toString(i);
+ }
+
+ try {
+ OutputStream fin = dfs.create(new Path(path, fn));
+ BufferedWriter wr = new BufferedWriter(new OutputStreamWriter(fin));
+ wr.write(Integer.toString(i));
+ wr.close();
+ fin.close();
+ } catch (IOException ioe) {
+ fail("Fail writing to input file: " + fn + " in directory: " + path + ioe.getMessage());
+ }
+ }
+ }
}
diff --git a/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSourceTest.java b/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSourceTest.java
index 21ca378..b121425 100644
--- a/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSourceTest.java
+++ b/samoa-api/src/test/java/com/yahoo/labs/samoa/streams/fs/LocalFileStreamSourceTest.java
@@ -43,234 +43,234 @@
import org.apache.commons.io.FileUtils;
public class LocalFileStreamSourceTest {
- private static final String BASE_DIR = "localfsTest";
- private static final int NUM_FILES_IN_DIR = 4;
- private static final int NUM_NOISE_FILES_IN_DIR = 2;
-
- private LocalFileStreamSource streamSource;
+ private static final String BASE_DIR = "localfsTest";
+ private static final int NUM_FILES_IN_DIR = 4;
+ private static final int NUM_NOISE_FILES_IN_DIR = 2;
- @Before
- public void setUp() throws Exception {
- streamSource = new LocalFileStreamSource();
-
- }
+ private LocalFileStreamSource streamSource;
- @After
- public void tearDown() throws Exception {
- FileUtils.deleteDirectory(new File(BASE_DIR));
- }
+ @Before
+ public void setUp() throws Exception {
+ streamSource = new LocalFileStreamSource();
- @Test
- public void testInitWithSingleFileAndExtension() {
- // write input file
- writeSimpleFiles(BASE_DIR,"txt",1);
-
- // init with path to input file
- File inFile = new File(BASE_DIR,"1.txt");
- String inFilePath = inFile.getAbsolutePath();
- streamSource.init(inFilePath, "txt");
-
- //assertions
- assertEquals("Size of filePaths is not correct.", 1,streamSource.getFilePathListSize(),0);
- String fn = streamSource.getFilePathAt(0);
- assertEquals("Incorrect file in filePaths.",inFilePath,fn);
- }
-
- @Test
- public void testInitWithSingleFileAndNullExtension() {
- // write input file
- writeSimpleFiles(BASE_DIR,"txt",1);
-
- // init with path to input file
- File inFile = new File(BASE_DIR,"1.txt");
- String inFilePath = inFile.getAbsolutePath();
- streamSource.init(inFilePath, null);
-
- //assertions
- assertEquals("Size of filePaths is not correct.", 1,streamSource.getFilePathListSize(),0);
- String fn = streamSource.getFilePathAt(0);
- assertEquals("Incorrect file in filePaths.",inFilePath,fn);
- }
-
- @Test
- public void testInitWithFolderAndExtension() {
- // write input file
- writeSimpleFiles(BASE_DIR,null,NUM_NOISE_FILES_IN_DIR);
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
-
- // init with path to input dir
- File inDir = new File(BASE_DIR);
- String inDirPath = inDir.getAbsolutePath();
- streamSource.init(inDirPath, "txt");
-
- //assertions
- assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR,streamSource.getFilePathListSize(),0);
- Set<String> filenames = new HashSet<String>();
- for (int i=1; i<=NUM_FILES_IN_DIR; i++) {
- String expectedFn = (new File(inDirPath,Integer.toString(i)+".txt")).getAbsolutePath();
- filenames.add(expectedFn);
- }
- for (int i=0; i< NUM_FILES_IN_DIR; i++) {
- String fn = streamSource.getFilePathAt(i);
- assertTrue("Incorrect file in filePaths:"+fn,filenames.contains(fn));
- }
- }
-
- @Test
- public void testInitWithFolderAndNullExtension() {
- // write input file
- writeSimpleFiles(BASE_DIR,null,NUM_FILES_IN_DIR);
-
- // init with path to input dir
- File inDir = new File(BASE_DIR);
- String inDirPath = inDir.getAbsolutePath();
- streamSource.init(inDirPath, null);
-
- //assertions
- assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR,streamSource.getFilePathListSize(),0);
- Set<String> filenames = new HashSet<String>();
- for (int i=1; i<=NUM_FILES_IN_DIR; i++) {
- String expectedFn = (new File(inDirPath,Integer.toString(i))).getAbsolutePath();
- filenames.add(expectedFn);
- }
- for (int i=0; i< NUM_FILES_IN_DIR; i++) {
- String fn = streamSource.getFilePathAt(i);
- assertTrue("Incorrect file in filePaths:"+fn,filenames.contains(fn));
- }
- }
-
- /*
- * getNextInputStream tests
- */
- @Test
- public void testGetNextInputStream() {
- // write input files & noise files
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
-
- // init with path to input dir
- streamSource.init(BASE_DIR, "txt");
-
- // call getNextInputStream & assertions
- Set<String> contents = new HashSet<String>();
- for (int i=1; i<=NUM_FILES_IN_DIR; i++) {
- contents.add(Integer.toString(i));
- }
- for (int i=0; i< NUM_FILES_IN_DIR; i++) {
- InputStream inStream = streamSource.getNextInputStream();
- assertNotNull("Unexpected end of input stream list.",inStream);
-
- BufferedReader rd = new BufferedReader(new InputStreamReader(inStream));
- String inputRead = null;
- try {
- inputRead = rd.readLine();
- } catch (IOException ioe) {
- fail("Fail reading from stream at index:"+i + ioe.getMessage());
- }
- assertTrue("File content is incorrect.",contents.contains(inputRead));
- Iterator<String> it = contents.iterator();
- while (it.hasNext()) {
- if (it.next().equals(inputRead)) {
- it.remove();
- break;
- }
- }
- }
-
- // assert that another call to getNextInputStream will return null
- assertNull("Call getNextInputStream after the last file did not return null.",streamSource.getNextInputStream());
- }
-
- /*
- * getCurrentInputStream tests
- */
- public void testGetCurrentInputStream() {
- // write input files & noise files
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
-
- // init with path to input dir
- streamSource.init(BASE_DIR, "txt");
-
- // call getNextInputStream, getCurrentInputStream & assertions
- for (int i=0; i<= NUM_FILES_IN_DIR; i++) { // test also after-end-of-list
- InputStream inStream1 = streamSource.getNextInputStream();
- InputStream inStream2 = streamSource.getCurrentInputStream();
- assertSame("Incorrect current input stream.",inStream1, inStream2);
- }
- }
-
- /*
- * reset tests
- */
- public void testReset() {
- // write input files & noise files
- writeSimpleFiles(BASE_DIR,"txt",NUM_FILES_IN_DIR);
+ }
- // init with path to input dir
- streamSource.init(BASE_DIR, "txt");
+ @After
+ public void tearDown() throws Exception {
+ FileUtils.deleteDirectory(new File(BASE_DIR));
+ }
- // Get the first input string
- InputStream firstInStream = streamSource.getNextInputStream();
- String firstInput = null;
- assertNotNull("Unexpected end of input stream list.",firstInStream);
+ @Test
+ public void testInitWithSingleFileAndExtension() {
+ // write input file
+ writeSimpleFiles(BASE_DIR, "txt", 1);
- BufferedReader rd1 = new BufferedReader(new InputStreamReader(firstInStream));
- try {
- firstInput = rd1.readLine();
- } catch (IOException ioe) {
- fail("Fail reading from stream at index:0" + ioe.getMessage());
- }
+ // init with path to input file
+ File inFile = new File(BASE_DIR, "1.txt");
+ String inFilePath = inFile.getAbsolutePath();
+ streamSource.init(inFilePath, "txt");
- // call getNextInputStream a few times
- streamSource.getNextInputStream();
+ // assertions
+ assertEquals("Size of filePaths is not correct.", 1, streamSource.getFilePathListSize(), 0);
+ String fn = streamSource.getFilePathAt(0);
+ assertEquals("Incorrect file in filePaths.", inFilePath, fn);
+ }
- // call reset, call next, assert that output is 1 (the first file)
- try {
- streamSource.reset();
- } catch (IOException ioe) {
- fail("Fail resetting stream source." + ioe.getMessage());
- }
+ @Test
+ public void testInitWithSingleFileAndNullExtension() {
+ // write input file
+ writeSimpleFiles(BASE_DIR, "txt", 1);
- InputStream inStream = streamSource.getNextInputStream();
- assertNotNull("Unexpected end of input stream list.",inStream);
+ // init with path to input file
+ File inFile = new File(BASE_DIR, "1.txt");
+ String inFilePath = inFile.getAbsolutePath();
+ streamSource.init(inFilePath, null);
- BufferedReader rd2 = new BufferedReader(new InputStreamReader(inStream));
- String inputRead = null;
- try {
- inputRead = rd2.readLine();
- } catch (IOException ioe) {
- fail("Fail reading from stream at index:0" + ioe.getMessage());
- }
- assertEquals("File content is incorrect.",firstInput,inputRead);
- }
-
- private void writeSimpleFiles(String path, String ext, int numOfFiles) {
- // Create folder
- File folder = new File(path);
- if (!folder.exists()) {
- try{
- folder.mkdir();
- } catch(SecurityException se){
- fail("Failed creating directory:"+path+se);
- }
- }
-
- // Write files
- for (int i=1; i<=numOfFiles; i++) {
- String fn = null;
- if (ext != null) {
- fn = Integer.toString(i) + "."+ ext;
- } else {
- fn = Integer.toString(i);
- }
-
- try {
- FileWriter fwr = new FileWriter(new File(path,fn));
- fwr.write(Integer.toString(i));
- fwr.close();
- } catch (IOException ioe) {
- fail("Fail writing to input file: "+ fn + " in directory: " + path + ioe.getMessage());
- }
- }
- }
+ // assertions
+ assertEquals("Size of filePaths is not correct.", 1, streamSource.getFilePathListSize(), 0);
+ String fn = streamSource.getFilePathAt(0);
+ assertEquals("Incorrect file in filePaths.", inFilePath, fn);
+ }
+
+ @Test
+ public void testInitWithFolderAndExtension() {
+ // write input file
+ writeSimpleFiles(BASE_DIR, null, NUM_NOISE_FILES_IN_DIR);
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ File inDir = new File(BASE_DIR);
+ String inDirPath = inDir.getAbsolutePath();
+ streamSource.init(inDirPath, "txt");
+
+ // assertions
+ assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR, streamSource.getFilePathListSize(), 0);
+ Set<String> filenames = new HashSet<String>();
+ for (int i = 1; i <= NUM_FILES_IN_DIR; i++) {
+ String expectedFn = (new File(inDirPath, Integer.toString(i) + ".txt")).getAbsolutePath();
+ filenames.add(expectedFn);
+ }
+ for (int i = 0; i < NUM_FILES_IN_DIR; i++) {
+ String fn = streamSource.getFilePathAt(i);
+ assertTrue("Incorrect file in filePaths:" + fn, filenames.contains(fn));
+ }
+ }
+
+ @Test
+ public void testInitWithFolderAndNullExtension() {
+ // write input file
+ writeSimpleFiles(BASE_DIR, null, NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ File inDir = new File(BASE_DIR);
+ String inDirPath = inDir.getAbsolutePath();
+ streamSource.init(inDirPath, null);
+
+ // assertions
+ assertEquals("Size of filePaths is not correct.", NUM_FILES_IN_DIR, streamSource.getFilePathListSize(), 0);
+ Set<String> filenames = new HashSet<String>();
+ for (int i = 1; i <= NUM_FILES_IN_DIR; i++) {
+ String expectedFn = (new File(inDirPath, Integer.toString(i))).getAbsolutePath();
+ filenames.add(expectedFn);
+ }
+ for (int i = 0; i < NUM_FILES_IN_DIR; i++) {
+ String fn = streamSource.getFilePathAt(i);
+ assertTrue("Incorrect file in filePaths:" + fn, filenames.contains(fn));
+ }
+ }
+
+ /*
+ * getNextInputStream tests
+ */
+ @Test
+ public void testGetNextInputStream() {
+ // write input files & noise files
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(BASE_DIR, "txt");
+
+ // call getNextInputStream & assertions
+ Set<String> contents = new HashSet<String>();
+ for (int i = 1; i <= NUM_FILES_IN_DIR; i++) {
+ contents.add(Integer.toString(i));
+ }
+ for (int i = 0; i < NUM_FILES_IN_DIR; i++) {
+ InputStream inStream = streamSource.getNextInputStream();
+ assertNotNull("Unexpected end of input stream list.", inStream);
+
+ BufferedReader rd = new BufferedReader(new InputStreamReader(inStream));
+ String inputRead = null;
+ try {
+ inputRead = rd.readLine();
+ } catch (IOException ioe) {
+ fail("Fail reading from stream at index:" + i + ioe.getMessage());
+ }
+ assertTrue("File content is incorrect.", contents.contains(inputRead));
+ Iterator<String> it = contents.iterator();
+ while (it.hasNext()) {
+ if (it.next().equals(inputRead)) {
+ it.remove();
+ break;
+ }
+ }
+ }
+
+ // assert that another call to getNextInputStream will return null
+ assertNull("Call getNextInputStream after the last file did not return null.", streamSource.getNextInputStream());
+ }
+
+ /*
+ * getCurrentInputStream tests
+ */
+ public void testGetCurrentInputStream() {
+ // write input files & noise files
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(BASE_DIR, "txt");
+
+ // call getNextInputStream, getCurrentInputStream & assertions
+ for (int i = 0; i <= NUM_FILES_IN_DIR; i++) { // test also after-end-of-list
+ InputStream inStream1 = streamSource.getNextInputStream();
+ InputStream inStream2 = streamSource.getCurrentInputStream();
+ assertSame("Incorrect current input stream.", inStream1, inStream2);
+ }
+ }
+
+ /*
+ * reset tests
+ */
+ public void testReset() {
+ // write input files & noise files
+ writeSimpleFiles(BASE_DIR, "txt", NUM_FILES_IN_DIR);
+
+ // init with path to input dir
+ streamSource.init(BASE_DIR, "txt");
+
+ // Get the first input string
+ InputStream firstInStream = streamSource.getNextInputStream();
+ String firstInput = null;
+ assertNotNull("Unexpected end of input stream list.", firstInStream);
+
+ BufferedReader rd1 = new BufferedReader(new InputStreamReader(firstInStream));
+ try {
+ firstInput = rd1.readLine();
+ } catch (IOException ioe) {
+ fail("Fail reading from stream at index:0" + ioe.getMessage());
+ }
+
+ // call getNextInputStream a few times
+ streamSource.getNextInputStream();
+
+ // call reset, call next, assert that output is 1 (the first file)
+ try {
+ streamSource.reset();
+ } catch (IOException ioe) {
+ fail("Fail resetting stream source." + ioe.getMessage());
+ }
+
+ InputStream inStream = streamSource.getNextInputStream();
+ assertNotNull("Unexpected end of input stream list.", inStream);
+
+ BufferedReader rd2 = new BufferedReader(new InputStreamReader(inStream));
+ String inputRead = null;
+ try {
+ inputRead = rd2.readLine();
+ } catch (IOException ioe) {
+ fail("Fail reading from stream at index:0" + ioe.getMessage());
+ }
+ assertEquals("File content is incorrect.", firstInput, inputRead);
+ }
+
+ private void writeSimpleFiles(String path, String ext, int numOfFiles) {
+ // Create folder
+ File folder = new File(path);
+ if (!folder.exists()) {
+ try {
+ folder.mkdir();
+ } catch (SecurityException se) {
+ fail("Failed creating directory:" + path + se);
+ }
+ }
+
+ // Write files
+ for (int i = 1; i <= numOfFiles; i++) {
+ String fn = null;
+ if (ext != null) {
+ fn = Integer.toString(i) + "." + ext;
+ } else {
+ fn = Integer.toString(i);
+ }
+
+ try {
+ FileWriter fwr = new FileWriter(new File(path, fn));
+ fwr.write(Integer.toString(i));
+ fwr.close();
+ } catch (IOException ioe) {
+ fail("Fail writing to input file: " + fn + " in directory: " + path + ioe.getMessage());
+ }
+ }
+ }
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/ArffLoader.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/ArffLoader.java
index 9476ec0..812c228 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/ArffLoader.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/ArffLoader.java
@@ -30,335 +30,340 @@
import java.util.logging.Logger;
/**
- *
+ *
* @author abifet
*/
public class ArffLoader implements Serializable {
- protected InstanceInformation instanceInformation;
+ protected InstanceInformation instanceInformation;
- transient protected StreamTokenizer streamTokenizer;
+ transient protected StreamTokenizer streamTokenizer;
- protected Reader reader;
+ protected Reader reader;
- protected int size;
+ protected int size;
- protected int classAttribute;
+ protected int classAttribute;
- public ArffLoader() {
+ public ArffLoader() {
+ }
+
+ public ArffLoader(Reader reader, int size, int classAttribute) {
+ this.reader = reader;
+ this.size = size;
+ this.classAttribute = classAttribute;
+ initStreamTokenizer(reader);
+ }
+
+ public InstanceInformation getStructure() {
+ return this.instanceInformation;
+ }
+
+ public Instance readInstance(Reader reader) {
+ if (streamTokenizer == null) {
+ initStreamTokenizer(reader);
+ }
+ while (streamTokenizer.ttype == StreamTokenizer.TT_EOL) {
+ try {
+ streamTokenizer.nextToken();
+ } catch (IOException ex) {
+ Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
+ }
+ }
+ if (streamTokenizer.ttype == '{') {
+ return readInstanceSparse();
+ // return readDenseInstanceSparse();
+ } else {
+ return readInstanceDense();
}
- public ArffLoader(Reader reader, int size, int classAttribute) {
- this.reader = reader;
- this.size = size;
- this.classAttribute = classAttribute;
- initStreamTokenizer(reader);
- }
+ }
- public InstanceInformation getStructure() {
- return this.instanceInformation;
- }
+ public Instance readInstanceDense() {
+ Instance instance = new DenseInstance(this.instanceInformation.numAttributes() + 1);
+ // System.out.println(this.instanceInformation.numAttributes());
+ int numAttribute = 0;
+ try {
+ while (numAttribute == 0 && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
+ // For each line
+ while (streamTokenizer.ttype != StreamTokenizer.TT_EOL
+ && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
+ // For each item
+ if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
+ // System.out.println(streamTokenizer.nval + "Num ");
+ this.setValue(instance, numAttribute, streamTokenizer.nval, true);
+ numAttribute++;
- public Instance readInstance(Reader reader) {
- if (streamTokenizer == null) {
- initStreamTokenizer(reader);
- }
- while (streamTokenizer.ttype == StreamTokenizer.TT_EOL) {
- try {
- streamTokenizer.nextToken();
- } catch (IOException ex) {
- Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
- }
- }
- if (streamTokenizer.ttype == '{') {
- return readInstanceSparse();
- // return readDenseInstanceSparse();
- } else {
- return readInstanceDense();
- }
-
- }
-
- public Instance readInstanceDense() {
- Instance instance = new DenseInstance(this.instanceInformation.numAttributes() + 1);
- //System.out.println(this.instanceInformation.numAttributes());
- int numAttribute = 0;
- try {
- while (numAttribute == 0 && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
- //For each line
- while (streamTokenizer.ttype != StreamTokenizer.TT_EOL
- && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
- //For each item
- if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
- //System.out.println(streamTokenizer.nval + "Num ");
- this.setValue(instance, numAttribute, streamTokenizer.nval, true);
- numAttribute++;
-
- } else if (streamTokenizer.sval != null && (streamTokenizer.ttype == StreamTokenizer.TT_WORD
- || streamTokenizer.ttype == 34)) {
- //System.out.println(streamTokenizer.sval + "Str");
- boolean isNumeric = attributes.get(numAttribute).isNumeric();
- double value;
- if ("?".equals(streamTokenizer.sval)) {
- value = Double.NaN; //Utils.missingValue();
- } else if (isNumeric == true) {
- value = Double.valueOf(streamTokenizer.sval).doubleValue();
- } else {
- value = this.instanceInformation.attribute(numAttribute).indexOfValue(streamTokenizer.sval);
- }
-
- this.setValue(instance, numAttribute, value, isNumeric);
- numAttribute++;
- }
- streamTokenizer.nextToken();
- }
- streamTokenizer.nextToken();
- //System.out.println("EOL");
+ } else if (streamTokenizer.sval != null && (streamTokenizer.ttype == StreamTokenizer.TT_WORD
+ || streamTokenizer.ttype == 34)) {
+ // System.out.println(streamTokenizer.sval + "Str");
+ boolean isNumeric = attributes.get(numAttribute).isNumeric();
+ double value;
+ if ("?".equals(streamTokenizer.sval)) {
+ value = Double.NaN; // Utils.missingValue();
+ } else if (isNumeric == true) {
+ value = Double.valueOf(streamTokenizer.sval).doubleValue();
+ } else {
+ value = this.instanceInformation.attribute(numAttribute).indexOfValue(streamTokenizer.sval);
}
-
- } catch (IOException ex) {
- Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
+ this.setValue(instance, numAttribute, value, isNumeric);
+ numAttribute++;
+ }
+ streamTokenizer.nextToken();
}
- return (numAttribute > 0) ? instance : null;
+ streamTokenizer.nextToken();
+ // System.out.println("EOL");
+ }
+
+ } catch (IOException ex) {
+ Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
}
+ return (numAttribute > 0) ? instance : null;
+ }
- private void setValue(Instance instance, int numAttribute, double value, boolean isNumber) {
- double valueAttribute;
- if (isNumber && this.instanceInformation.attribute(numAttribute).isNominal) {
- valueAttribute = this.instanceInformation.attribute(numAttribute).indexOfValue(Double.toString(value));
- //System.out.println(value +"/"+valueAttribute+" ");
+ private void setValue(Instance instance, int numAttribute, double value, boolean isNumber) {
+ double valueAttribute;
+ if (isNumber && this.instanceInformation.attribute(numAttribute).isNominal) {
+ valueAttribute = this.instanceInformation.attribute(numAttribute).indexOfValue(Double.toString(value));
+ // System.out.println(value +"/"+valueAttribute+" ");
- } else {
- valueAttribute = value;
- //System.out.println(value +"/"+valueAttribute+" ");
- }
- if (this.instanceInformation.classIndex() == numAttribute) {
- instance.setClassValue(valueAttribute);
- //System.out.println(value +"<"+this.instanceInformation.classIndex()+">");
- } else {
- instance.setValue(numAttribute, valueAttribute);
- }
+ } else {
+ valueAttribute = value;
+ // System.out.println(value +"/"+valueAttribute+" ");
}
+ if (this.instanceInformation.classIndex() == numAttribute) {
+ instance.setClassValue(valueAttribute);
+ // System.out.println(value
+ // +"<"+this.instanceInformation.classIndex()+">");
+ } else {
+ instance.setValue(numAttribute, valueAttribute);
+ }
+ }
- private Instance readInstanceSparse() {
- //Return a Sparse Instance
- Instance instance = new SparseInstance(1.0, null); //(this.instanceInformation.numAttributes() + 1);
- //System.out.println(this.instanceInformation.numAttributes());
- int numAttribute;
- ArrayList<Double> attributeValues = new ArrayList<Double>();
- List<Integer> indexValues = new ArrayList<Integer>();
- try {
- //while (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
- streamTokenizer.nextToken(); // Remove the '{' char
- //For each line
- while (streamTokenizer.ttype != StreamTokenizer.TT_EOL
- && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
- while (streamTokenizer.ttype != '}') {
- //For each item
- //streamTokenizer.nextToken();
- //while (streamTokenizer.ttype != '}'){
- //System.out.println(streamTokenizer.nval +"-"+ streamTokenizer.sval);
- //numAttribute = (int) streamTokenizer.nval;
- if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
- numAttribute = (int) streamTokenizer.nval;
- } else {
- numAttribute = Integer.parseInt(streamTokenizer.sval);
- }
- streamTokenizer.nextToken();
+ private Instance readInstanceSparse() {
+ // Return a Sparse Instance
+ Instance instance = new SparseInstance(1.0, null); // (this.instanceInformation.numAttributes()
+ // + 1);
+ // System.out.println(this.instanceInformation.numAttributes());
+ int numAttribute;
+ ArrayList<Double> attributeValues = new ArrayList<Double>();
+ List<Integer> indexValues = new ArrayList<Integer>();
+ try {
+ // while (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
+ streamTokenizer.nextToken(); // Remove the '{' char
+ // For each line
+ while (streamTokenizer.ttype != StreamTokenizer.TT_EOL
+ && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
+ while (streamTokenizer.ttype != '}') {
+ // For each item
+ // streamTokenizer.nextToken();
+ // while (streamTokenizer.ttype != '}'){
+ // System.out.println(streamTokenizer.nval +"-"+
+ // streamTokenizer.sval);
+ // numAttribute = (int) streamTokenizer.nval;
+ if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
+ numAttribute = (int) streamTokenizer.nval;
+ } else {
+ numAttribute = Integer.parseInt(streamTokenizer.sval);
+ }
+ streamTokenizer.nextToken();
- if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
- //System.out.print(streamTokenizer.nval + " ");
- this.setSparseValue(instance, indexValues, attributeValues, numAttribute, streamTokenizer.nval, true);
- //numAttribute++;
+ if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
+ // System.out.print(streamTokenizer.nval + " ");
+ this.setSparseValue(instance, indexValues, attributeValues, numAttribute, streamTokenizer.nval, true);
+ // numAttribute++;
- } else if (streamTokenizer.sval != null && (streamTokenizer.ttype == StreamTokenizer.TT_WORD
- || streamTokenizer.ttype == 34)) {
- //System.out.print(streamTokenizer.sval + "-");
- if (attributes.get(numAttribute).isNumeric()) {
- this.setSparseValue(instance, indexValues, attributeValues, numAttribute, Double.valueOf(streamTokenizer.sval).doubleValue(), true);
- } else {
- this.setSparseValue(instance, indexValues, attributeValues, numAttribute, this.instanceInformation.attribute(numAttribute).indexOfValue(streamTokenizer.sval), false);
- }
- }
- streamTokenizer.nextToken();
- }
- streamTokenizer.nextToken(); //Remove the '}' char
+ } else if (streamTokenizer.sval != null && (streamTokenizer.ttype == StreamTokenizer.TT_WORD
+ || streamTokenizer.ttype == 34)) {
+ // System.out.print(streamTokenizer.sval + "-");
+ if (attributes.get(numAttribute).isNumeric()) {
+ this.setSparseValue(instance, indexValues, attributeValues, numAttribute,
+ Double.valueOf(streamTokenizer.sval).doubleValue(), true);
+ } else {
+ this.setSparseValue(instance, indexValues, attributeValues, numAttribute, this.instanceInformation
+ .attribute(numAttribute).indexOfValue(streamTokenizer.sval), false);
+ }
+ }
+ streamTokenizer.nextToken();
+ }
+ streamTokenizer.nextToken(); // Remove the '}' char
+ }
+ streamTokenizer.nextToken();
+ // System.out.println("EOL");
+ // }
+
+ } catch (IOException ex) {
+ Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
+ }
+ int[] arrayIndexValues = new int[attributeValues.size()];
+ double[] arrayAttributeValues = new double[attributeValues.size()];
+ for (int i = 0; i < arrayIndexValues.length; i++) {
+ arrayIndexValues[i] = indexValues.get(i).intValue();
+ arrayAttributeValues[i] = attributeValues.get(i).doubleValue();
+ }
+ instance.addSparseValues(arrayIndexValues, arrayAttributeValues, this.instanceInformation.numAttributes());
+ return instance;
+
+ }
+
+ private void setSparseValue(Instance instance, List<Integer> indexValues, List<Double> attributeValues,
+ int numAttribute, double value, boolean isNumber) {
+ double valueAttribute;
+ if (isNumber && this.instanceInformation.attribute(numAttribute).isNominal) {
+ valueAttribute = this.instanceInformation.attribute(numAttribute).indexOfValue(Double.toString(value));
+ } else {
+ valueAttribute = value;
+ }
+ if (this.instanceInformation.classIndex() == numAttribute) {
+ instance.setClassValue(valueAttribute);
+ } else {
+ // instance.setValue(numAttribute, valueAttribute);
+ indexValues.add(numAttribute);
+ attributeValues.add(valueAttribute);
+ }
+ // System.out.println(numAttribute+":"+valueAttribute+","+this.instanceInformation.classIndex()+","+value);
+ }
+
+ private Instance readDenseInstanceSparse() {
+ // Returns a dense instance
+ Instance instance = new DenseInstance(this.instanceInformation.numAttributes() + 1);
+ // System.out.println(this.instanceInformation.numAttributes());
+ int numAttribute;
+ try {
+ // while (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
+ streamTokenizer.nextToken(); // Remove the '{' char
+ // For each line
+ while (streamTokenizer.ttype != StreamTokenizer.TT_EOL
+ && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
+ while (streamTokenizer.ttype != '}') {
+ // For each item
+ // streamTokenizer.nextToken();
+ // while (streamTokenizer.ttype != '}'){
+ // System.out.print(streamTokenizer.nval+":");
+ numAttribute = (int) streamTokenizer.nval;
+ streamTokenizer.nextToken();
+
+ if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
+ // System.out.print(streamTokenizer.nval + " ");
+ this.setValue(instance, numAttribute, streamTokenizer.nval, true);
+ // numAttribute++;
+
+ } else if (streamTokenizer.sval != null && (streamTokenizer.ttype == StreamTokenizer.TT_WORD
+ || streamTokenizer.ttype == 34)) {
+ // System.out.print(streamTokenizer.sval +
+ // "/"+this.instanceInformation.attribute(numAttribute).indexOfValue(streamTokenizer.sval)+" ");
+ if (attributes.get(numAttribute).isNumeric()) {
+ this.setValue(instance, numAttribute, Double.valueOf(streamTokenizer.sval).doubleValue(), true);
+ } else {
+ this.setValue(instance, numAttribute,
+ this.instanceInformation.attribute(numAttribute).indexOfValue(streamTokenizer.sval), false);
+ // numAttribute++;
+ }
+ }
+ streamTokenizer.nextToken();
+ }
+ streamTokenizer.nextToken(); // Remove the '}' char
+ }
+ streamTokenizer.nextToken();
+ // System.out.println("EOL");
+ // }
+
+ } catch (IOException ex) {
+ Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
+ }
+ return instance;
+ }
+
+ protected List<Attribute> attributes;
+
+ private InstanceInformation getHeader() {
+
+ String relation = "file stream";
+ // System.out.println("RELATION " + relation);
+ attributes = new ArrayList<Attribute>();
+ try {
+ streamTokenizer.nextToken();
+ while (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
+ // For each line
+ // if (streamTokenizer.ttype == '@') {
+ if (streamTokenizer.ttype == StreamTokenizer.TT_WORD && streamTokenizer.sval.startsWith("@") == true) {
+ // streamTokenizer.nextToken();
+ String token = streamTokenizer.sval.toUpperCase();
+ if (token.startsWith("@RELATION")) {
+ streamTokenizer.nextToken();
+ relation = streamTokenizer.sval;
+ // System.out.println("RELATION " + relation);
+ } else if (token.startsWith("@ATTRIBUTE")) {
+ streamTokenizer.nextToken();
+ String name = streamTokenizer.sval;
+ // System.out.println("* " + name);
+ if (name == null) {
+ name = Double.toString(streamTokenizer.nval);
}
streamTokenizer.nextToken();
- //System.out.println("EOL");
- //}
+ String type = streamTokenizer.sval;
+ // System.out.println("* " + name + ":" + type + " ");
+ if (streamTokenizer.ttype == '{') {
+ streamTokenizer.nextToken();
+ List<String> attributeLabels = new ArrayList<String>();
+ while (streamTokenizer.ttype != '}') {
-
- } catch (IOException ex) {
- Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
- }
- int[] arrayIndexValues = new int[attributeValues.size()];
- double[] arrayAttributeValues = new double[attributeValues.size()];
- for (int i = 0; i < arrayIndexValues.length; i++) {
- arrayIndexValues[i] = indexValues.get(i).intValue();
- arrayAttributeValues[i] = attributeValues.get(i).doubleValue();
- }
- instance.addSparseValues(arrayIndexValues, arrayAttributeValues, this.instanceInformation.numAttributes());
- return instance;
-
- }
-
- private void setSparseValue(Instance instance, List<Integer> indexValues, List<Double> attributeValues, int numAttribute, double value, boolean isNumber) {
- double valueAttribute;
- if (isNumber && this.instanceInformation.attribute(numAttribute).isNominal) {
- valueAttribute = this.instanceInformation.attribute(numAttribute).indexOfValue(Double.toString(value));
- } else {
- valueAttribute = value;
- }
- if (this.instanceInformation.classIndex() == numAttribute) {
- instance.setClassValue(valueAttribute);
- } else {
- //instance.setValue(numAttribute, valueAttribute);
- indexValues.add(numAttribute);
- attributeValues.add(valueAttribute);
- }
- //System.out.println(numAttribute+":"+valueAttribute+","+this.instanceInformation.classIndex()+","+value);
- }
-
- private Instance readDenseInstanceSparse() {
- //Returns a dense instance
- Instance instance = new DenseInstance(this.instanceInformation.numAttributes() + 1);
- //System.out.println(this.instanceInformation.numAttributes());
- int numAttribute;
- try {
- //while (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
- streamTokenizer.nextToken(); // Remove the '{' char
- //For each line
- while (streamTokenizer.ttype != StreamTokenizer.TT_EOL
- && streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
- while (streamTokenizer.ttype != '}') {
- //For each item
- //streamTokenizer.nextToken();
- //while (streamTokenizer.ttype != '}'){
- //System.out.print(streamTokenizer.nval+":");
- numAttribute = (int) streamTokenizer.nval;
- streamTokenizer.nextToken();
-
- if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) {
- //System.out.print(streamTokenizer.nval + " ");
- this.setValue(instance, numAttribute, streamTokenizer.nval, true);
- //numAttribute++;
-
- } else if (streamTokenizer.sval != null && (streamTokenizer.ttype == StreamTokenizer.TT_WORD
- || streamTokenizer.ttype == 34)) {
- //System.out.print(streamTokenizer.sval + "/"+this.instanceInformation.attribute(numAttribute).indexOfValue(streamTokenizer.sval)+" ");
- if (attributes.get(numAttribute).isNumeric()) {
- this.setValue(instance, numAttribute, Double.valueOf(streamTokenizer.sval).doubleValue(), true);
- } else {
- this.setValue(instance, numAttribute, this.instanceInformation.attribute(numAttribute).indexOfValue(streamTokenizer.sval), false);
- //numAttribute++;
- }
- }
- streamTokenizer.nextToken();
+ if (streamTokenizer.sval != null) {
+ attributeLabels.add(streamTokenizer.sval);
+ // System.out.print(streamTokenizer.sval + ",");
+ } else {
+ attributeLabels.add(Double.toString(streamTokenizer.nval));
+ // System.out.print(streamTokenizer.nval + ",");
}
- streamTokenizer.nextToken(); //Remove the '}' char
- }
- streamTokenizer.nextToken();
- //System.out.println("EOL");
- //}
-
- } catch (IOException ex) {
- Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
- }
- return instance;
- }
-
- protected List<Attribute> attributes;
-
- private InstanceInformation getHeader() {
-
- String relation = "file stream";
- //System.out.println("RELATION " + relation);
- attributes = new ArrayList<Attribute>();
- try {
- streamTokenizer.nextToken();
- while (streamTokenizer.ttype != StreamTokenizer.TT_EOF) {
- //For each line
- //if (streamTokenizer.ttype == '@') {
- if (streamTokenizer.ttype == StreamTokenizer.TT_WORD && streamTokenizer.sval.startsWith("@") == true) {
- //streamTokenizer.nextToken();
- String token = streamTokenizer.sval.toUpperCase();
- if (token.startsWith("@RELATION")) {
- streamTokenizer.nextToken();
- relation = streamTokenizer.sval;
- //System.out.println("RELATION " + relation);
- } else if (token.startsWith("@ATTRIBUTE")) {
- streamTokenizer.nextToken();
- String name = streamTokenizer.sval;
- //System.out.println("* " + name);
- if (name == null) {
- name = Double.toString(streamTokenizer.nval);
- }
- streamTokenizer.nextToken();
- String type = streamTokenizer.sval;
- //System.out.println("* " + name + ":" + type + " ");
- if (streamTokenizer.ttype == '{') {
- streamTokenizer.nextToken();
- List<String> attributeLabels = new ArrayList<String>();
- while (streamTokenizer.ttype != '}') {
-
- if (streamTokenizer.sval != null) {
- attributeLabels.add(streamTokenizer.sval);
- //System.out.print(streamTokenizer.sval + ",");
- } else {
- attributeLabels.add(Double.toString(streamTokenizer.nval));
- //System.out.print(streamTokenizer.nval + ",");
- }
-
- streamTokenizer.nextToken();
- }
- //System.out.println();
- attributes.add(new Attribute(name, attributeLabels));
- } else {
- // Add attribute
- attributes.add(new Attribute(name));
- }
-
- } else if (token.startsWith("@DATA")) {
- //System.out.print("END");
- streamTokenizer.nextToken();
- break;
- }
- }
streamTokenizer.nextToken();
+ }
+ // System.out.println();
+ attributes.add(new Attribute(name, attributeLabels));
+ } else {
+ // Add attribute
+ attributes.add(new Attribute(name));
}
- } catch (IOException ex) {
- Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
+ } else if (token.startsWith("@DATA")) {
+ // System.out.print("END");
+ streamTokenizer.nextToken();
+ break;
+ }
}
- return new InstanceInformation(relation, attributes);
+ streamTokenizer.nextToken();
+ }
+
+ } catch (IOException ex) {
+ Logger.getLogger(ArffLoader.class.getName()).log(Level.SEVERE, null, ex);
}
+ return new InstanceInformation(relation, attributes);
+ }
- private void initStreamTokenizer(Reader reader) {
- BufferedReader br = new BufferedReader(reader);
+ private void initStreamTokenizer(Reader reader) {
+ BufferedReader br = new BufferedReader(reader);
- //Init streamTokenizer
- streamTokenizer = new StreamTokenizer(br);
+ // Init streamTokenizer
+ streamTokenizer = new StreamTokenizer(br);
- streamTokenizer.resetSyntax();
- streamTokenizer.whitespaceChars(0, ' ');
- streamTokenizer.wordChars(' ' + 1, '\u00FF');
- streamTokenizer.whitespaceChars(',', ',');
- streamTokenizer.commentChar('%');
- streamTokenizer.quoteChar('"');
- streamTokenizer.quoteChar('\'');
- streamTokenizer.ordinaryChar('{');
- streamTokenizer.ordinaryChar('}');
- streamTokenizer.eolIsSignificant(true);
+ streamTokenizer.resetSyntax();
+ streamTokenizer.whitespaceChars(0, ' ');
+ streamTokenizer.wordChars(' ' + 1, '\u00FF');
+ streamTokenizer.whitespaceChars(',', ',');
+ streamTokenizer.commentChar('%');
+ streamTokenizer.quoteChar('"');
+ streamTokenizer.quoteChar('\'');
+ streamTokenizer.ordinaryChar('{');
+ streamTokenizer.ordinaryChar('}');
+ streamTokenizer.eolIsSignificant(true);
- this.instanceInformation = this.getHeader();
- if (classAttribute < 0) {
- this.instanceInformation.setClassIndex(this.instanceInformation.numAttributes() - 1);
- //System.out.print(this.instanceInformation.classIndex());
- } else if (classAttribute > 0) {
- this.instanceInformation.setClassIndex(classAttribute - 1);
- }
+ this.instanceInformation = this.getHeader();
+ if (classAttribute < 0) {
+ this.instanceInformation.setClassIndex(this.instanceInformation.numAttributes() - 1);
+ // System.out.print(this.instanceInformation.classIndex());
+ } else if (classAttribute > 0) {
+ this.instanceInformation.setClassIndex(classAttribute - 1);
}
+ }
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Attribute.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Attribute.java
index 8f3873c..68d4808 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Attribute.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Attribute.java
@@ -32,174 +32,174 @@
import java.util.Map;
/**
- *
+ *
* @author abifet
*/
-public class Attribute implements Serializable{
+public class Attribute implements Serializable {
-
- public static final String ARFF_ATTRIBUTE = "@attribute";
- public static final String ARFF_ATTRIBUTE_NUMERIC = "NUMERIC";
-
-
- /**
- *
- */
- protected boolean isNominal;
- /**
- *
- */
- protected boolean isNumeric;
- /**
- *
- */
- protected boolean isDate;
- /**
- *
- */
- protected String name;
- /**
- *
- */
- protected List<String> attributeValues;
+ public static final String ARFF_ATTRIBUTE = "@attribute";
+ public static final String ARFF_ATTRIBUTE_NUMERIC = "NUMERIC";
- /**
+ /**
*
- * @return
*/
- public List<String> getAttributeValues() {
- return attributeValues;
+ protected boolean isNominal;
+ /**
+ *
+ */
+ protected boolean isNumeric;
+ /**
+ *
+ */
+ protected boolean isDate;
+ /**
+ *
+ */
+ protected String name;
+ /**
+ *
+ */
+ protected List<String> attributeValues;
+
+ /**
+ *
+ * @return
+ */
+ public List<String> getAttributeValues() {
+ return attributeValues;
+ }
+
+ /**
+ *
+ */
+ protected int index;
+
+ /**
+ *
+ * @param string
+ */
+ public Attribute(String string) {
+ this.name = string;
+ this.isNumeric = true;
+ }
+
+ /**
+ *
+ * @param attributeName
+ * @param attributeValues
+ */
+ public Attribute(String attributeName, List<String> attributeValues) {
+ this.name = attributeName;
+ this.attributeValues = attributeValues;
+ this.isNominal = true;
+ }
+
+ /**
+ *
+ */
+ public Attribute() {
+ this("");
+ }
+
+ /**
+ *
+ * @return
+ */
+ public boolean isNominal() {
+ return this.isNominal;
+ }
+
+ /**
+ *
+ * @return
+ */
+ public String name() {
+ return this.name;
+ }
+
+ /**
+ *
+ * @param value
+ * @return
+ */
+ public String value(int value) {
+ return attributeValues.get(value);
+ }
+
+ /**
+ *
+ * @return
+ */
+ public boolean isNumeric() {
+ return isNumeric;
+ }
+
+ /**
+ *
+ * @return
+ */
+ public int numValues() {
+ if (isNumeric()) {
+ return 0;
}
- /**
- *
- */
- protected int index;
-
- /**
- *
- * @param string
- */
- public Attribute(String string) {
- this.name = string;
- this.isNumeric = true;
+ else {
+ return attributeValues.size();
}
+ }
- /**
- *
- * @param attributeName
- * @param attributeValues
- */
- public Attribute(String attributeName, List<String> attributeValues) {
- this.name = attributeName;
- this.attributeValues = attributeValues;
- this.isNominal = true;
- }
-
- /**
- *
- */
- public Attribute() {
- this("");
- }
+ /**
+ *
+ * @return
+ */
+ public int index() { // RuleClassifier
+ return this.index;
+ }
- /**
- *
- * @return
- */
- public boolean isNominal() {
- return this.isNominal;
- }
+ String formatDate(double value) {
+ SimpleDateFormat sdf = new SimpleDateFormat();
+ return sdf.format(new Date((long) value));
+ }
- /**
- *
- * @return
- */
- public String name() {
- return this.name;
- }
+ boolean isDate() {
+ return isDate;
+ }
- /**
- *
- * @param value
- * @return
- */
- public String value(int value) {
- return attributeValues.get(value);
- }
+ private Map<String, Integer> valuesStringAttribute;
- /**
- *
- * @return
- */
- public boolean isNumeric() {
- return isNumeric;
- }
+ /**
+ *
+ * @param value
+ * @return
+ */
+ public final int indexOfValue(String value) {
- /**
- *
- * @return
- */
- public int numValues() {
- if (isNumeric()) {
- return 0;
- }
- else {
- return attributeValues.size();
- }
+ if (isNominal() == false) {
+ return -1;
}
-
- /**
- *
- * @return
- */
- public int index() { //RuleClassifier
- return this.index;
+ if (this.valuesStringAttribute == null) {
+ this.valuesStringAttribute = new HashMap<String, Integer>();
+ int count = 0;
+ for (String stringValue : attributeValues) {
+ this.valuesStringAttribute.put(stringValue, count);
+ count++;
+ }
}
-
- String formatDate(double value) {
- SimpleDateFormat sdf = new SimpleDateFormat();
- return sdf.format(new Date((long) value));
+ Integer val = (Integer) this.valuesStringAttribute.get(value);
+ if (val == null) {
+ return -1;
+ } else {
+ return val.intValue();
}
+ }
- boolean isDate() {
- return isDate;
- }
- private Map<String, Integer> valuesStringAttribute;
+ @Override
+ public String toString() {
+ StringBuffer text = new StringBuffer();
- /**
- *
- * @param value
- * @return
- */
- public final int indexOfValue(String value) {
+ text.append(ARFF_ATTRIBUTE).append(" ").append(Utils.quote(this.name)).append(" ");
- if (isNominal() == false) {
- return -1;
- }
- if (this.valuesStringAttribute == null) {
- this.valuesStringAttribute = new HashMap<String, Integer>();
- int count = 0;
- for (String stringValue : attributeValues) {
- this.valuesStringAttribute.put(stringValue, count);
- count++;
- }
- }
- Integer val = (Integer) this.valuesStringAttribute.get(value);
- if (val == null) {
- return -1;
- } else {
- return val.intValue();
- }
- }
+ text.append(ARFF_ATTRIBUTE_NUMERIC);
- @Override
- public String toString() {
- StringBuffer text = new StringBuffer();
-
- text.append(ARFF_ATTRIBUTE).append(" ").append(Utils.quote(this.name)).append(" ");
-
- text.append(ARFF_ATTRIBUTE_NUMERIC);
-
- return text.toString();
- }
+ return text.toString();
+ }
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstance.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstance.java
index b63f736..cf5f8a4 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstance.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstance.java
@@ -25,48 +25,50 @@
*/
/**
- *
+ *
* @author abifet
*/
public class DenseInstance extends SingleLabelInstance {
- private static final long serialVersionUID = 280360594027716737L;
+ private static final long serialVersionUID = 280360594027716737L;
- public DenseInstance() {
- // necessary for kryo serializer
- }
-
- public DenseInstance(double weight, double[] res) {
- super(weight,res);
- }
- public DenseInstance(SingleLabelInstance inst) {
- super(inst);
- }
-
- public DenseInstance(Instance inst) {
- super((SingleLabelInstance) inst);
- }
- public DenseInstance(double numberAttributes) {
- super((int) numberAttributes);
- //super(1, new double[(int) numberAttributes-1]);
- //Add missing values
- //for (int i = 0; i < numberAttributes-1; i++) {
- // //this.setValue(i, Double.NaN);
- //}
-
- }
-
- @Override
- public String toString() {
- StringBuffer text = new StringBuffer();
+ public DenseInstance() {
+ // necessary for kryo serializer
+ }
- for (int i = 0; i < this.instanceInformation.numAttributes(); i++) {
- if (i > 0)
- text.append(",");
- text.append(this.value(i));
- }
- text.append(",").append(this.weight());
-
- return text.toString();
+ public DenseInstance(double weight, double[] res) {
+ super(weight, res);
+ }
+
+ public DenseInstance(SingleLabelInstance inst) {
+ super(inst);
+ }
+
+ public DenseInstance(Instance inst) {
+ super((SingleLabelInstance) inst);
+ }
+
+ public DenseInstance(double numberAttributes) {
+ super((int) numberAttributes);
+ // super(1, new double[(int) numberAttributes-1]);
+ // Add missing values
+ // for (int i = 0; i < numberAttributes-1; i++) {
+ // //this.setValue(i, Double.NaN);
+ // }
+
+ }
+
+ @Override
+ public String toString() {
+ StringBuffer text = new StringBuffer();
+
+ for (int i = 0; i < this.instanceInformation.numAttributes(); i++) {
+ if (i > 0)
+ text.append(",");
+ text.append(this.value(i));
}
+ text.append(",").append(this.weight());
+
+ return text.toString();
+ }
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstanceData.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstanceData.java
index e83a4d9..b92d7f3 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstanceData.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/DenseInstanceData.java
@@ -25,73 +25,73 @@
*/
/**
- *
+ *
* @author abifet
*/
-public class DenseInstanceData implements InstanceData{
+public class DenseInstanceData implements InstanceData {
- public DenseInstanceData(double[] array) {
- this.attributeValues = array;
- }
-
- public DenseInstanceData(int length) {
- this.attributeValues = new double[length];
- }
-
- public DenseInstanceData() {
- this(0);
- }
-
- protected double[] attributeValues;
+ public DenseInstanceData(double[] array) {
+ this.attributeValues = array;
+ }
- @Override
- public int numAttributes() {
- return this.attributeValues.length;
- }
+ public DenseInstanceData(int length) {
+ this.attributeValues = new double[length];
+ }
- @Override
- public double value(int indexAttribute) {
- return this.attributeValues[indexAttribute];
- }
+ public DenseInstanceData() {
+ this(0);
+ }
- @Override
- public boolean isMissing(int indexAttribute) {
- return Double.isNaN(this.value(indexAttribute));
- }
+ protected double[] attributeValues;
- @Override
- public int numValues() {
- return numAttributes();
- }
+ @Override
+ public int numAttributes() {
+ return this.attributeValues.length;
+ }
- @Override
- public int index(int indexAttribute) {
- return indexAttribute;
- }
+ @Override
+ public double value(int indexAttribute) {
+ return this.attributeValues[indexAttribute];
+ }
- @Override
- public double valueSparse(int indexAttribute) {
- return value(indexAttribute);
- }
+ @Override
+ public boolean isMissing(int indexAttribute) {
+ return Double.isNaN(this.value(indexAttribute));
+ }
- @Override
- public boolean isMissingSparse(int indexAttribute) {
- return isMissing(indexAttribute);
- }
+ @Override
+ public int numValues() {
+ return numAttributes();
+ }
- /*@Override
- public double value(Attribute attribute) {
- return value(attribute.index());
- }*/
+ @Override
+ public int index(int indexAttribute) {
+ return indexAttribute;
+ }
- @Override
- public double[] toDoubleArray() {
- return attributeValues.clone();
- }
+ @Override
+ public double valueSparse(int indexAttribute) {
+ return value(indexAttribute);
+ }
- @Override
- public void setValue(int attributeIndex, double d) {
- this.attributeValues[attributeIndex] = d;
- }
-
+ @Override
+ public boolean isMissingSparse(int indexAttribute) {
+ return isMissing(indexAttribute);
+ }
+
+ /*
+ * @Override public double value(Attribute attribute) { return
+ * value(attribute.index()); }
+ */
+
+ @Override
+ public double[] toDoubleArray() {
+ return attributeValues.clone();
+ }
+
+ @Override
+ public void setValue(int attributeIndex, double d) {
+ this.attributeValues[attributeIndex] = d;
+ }
+
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instance.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instance.java
index 7d8e337..db5caac 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instance.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instance.java
@@ -27,48 +27,67 @@
import java.io.Serializable;
/**
- *
+ *
* @author abifet
*/
-public interface Instance extends Serializable{
+public interface Instance extends Serializable {
- double weight();
- void setWeight(double weight);
-
- //Attributes
- Attribute attribute(int instAttIndex);
- void deleteAttributeAt(int i);
- void insertAttributeAt(int i);
- int numAttributes();
- public void addSparseValues(int[] indexValues, double[] attributeValues, int numberAttributes);
-
+ double weight();
- //Values
- int numValues();
- String stringValue(int i);
- double value(int instAttIndex);
- double value(Attribute attribute);
- void setValue(int m_numAttributes, double d);
- boolean isMissing(int instAttIndex);
- int index(int i);
- double valueSparse(int i);
- boolean isMissingSparse(int p1);
- double[] toDoubleArray();
-
- //Class
- Attribute classAttribute();
- int classIndex();
- boolean classIsMissing();
- double classValue();
- int numClasses();
- void setClassValue(double d);
+ void setWeight(double weight);
- Instance copy();
+ // Attributes
+ Attribute attribute(int instAttIndex);
- //Dataset
- void setDataset(Instances dataset);
- Instances dataset();
- String toString();
+ void deleteAttributeAt(int i);
+
+ void insertAttributeAt(int i);
+
+ int numAttributes();
+
+ public void addSparseValues(int[] indexValues, double[] attributeValues, int numberAttributes);
+
+ // Values
+ int numValues();
+
+ String stringValue(int i);
+
+ double value(int instAttIndex);
+
+ double value(Attribute attribute);
+
+ void setValue(int m_numAttributes, double d);
+
+ boolean isMissing(int instAttIndex);
+
+ int index(int i);
+
+ double valueSparse(int i);
+
+ boolean isMissingSparse(int p1);
+
+ double[] toDoubleArray();
+
+ // Class
+ Attribute classAttribute();
+
+ int classIndex();
+
+ boolean classIsMissing();
+
+ double classValue();
+
+ int numClasses();
+
+ void setClassValue(double d);
+
+ Instance copy();
+
+ // Dataset
+ void setDataset(Instances dataset);
+
+ Instances dataset();
+
+ String toString();
}
-
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceData.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceData.java
index e8492a9..54b02b0 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceData.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceData.java
@@ -27,29 +27,29 @@
import java.io.Serializable;
/**
- *
+ *
* @author abifet
*/
-public interface InstanceData extends Serializable{
+public interface InstanceData extends Serializable {
- public int numAttributes();
+ public int numAttributes();
- public double value(int instAttIndex);
+ public double value(int instAttIndex);
- public boolean isMissing(int instAttIndex);
+ public boolean isMissing(int instAttIndex);
- public int numValues();
+ public int numValues();
- public int index(int i);
+ public int index(int i);
- public double valueSparse(int i);
+ public double valueSparse(int i);
- public boolean isMissingSparse(int p1);
+ public boolean isMissingSparse(int p1);
- //public double value(Attribute attribute);
+ // public double value(Attribute attribute);
- public double[] toDoubleArray();
+ public double[] toDoubleArray();
- public void setValue(int m_numAttributes, double d);
-
+ public void setValue(int m_numAttributes, double d);
+
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceInformation.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceInformation.java
index ff22762..b00ae3c 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceInformation.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstanceInformation.java
@@ -28,85 +28,81 @@
import java.util.List;
/**
- *
+ *
* @author abifet
*/
-public class InstanceInformation implements Serializable{
-
- //Should we split Instances as a List of Instances, and InformationInstances
-
+public class InstanceInformation implements Serializable {
+
+ // Should we split Instances as a List of Instances, and InformationInstances
+
/** The dataset's name. */
- protected String relationName;
+ protected String relationName;
/** The attribute information. */
protected List<Attribute> attributes;
-
+
protected int classIndex;
-
-
- public InstanceInformation(InstanceInformation chunk) {
- this.relationName = chunk.relationName;
- this.attributes = chunk.attributes;
- this.classIndex = chunk.classIndex;
- }
-
- public InstanceInformation(String st, List<Attribute> v) {
- this.relationName = st;
- this.attributes = v;
- }
-
- public InstanceInformation() {
- this.relationName = null;
- this.attributes = null;
- }
-
-
- //Information Instances
-
- public void setRelationName(String string) {
- this.relationName = string;
- }
+ public InstanceInformation(InstanceInformation chunk) {
+ this.relationName = chunk.relationName;
+ this.attributes = chunk.attributes;
+ this.classIndex = chunk.classIndex;
+ }
- public String getRelationName() {
- return this.relationName;
- }
-
- public int classIndex() {
- return classIndex;
- }
+ public InstanceInformation(String st, List<Attribute> v) {
+ this.relationName = st;
+ this.attributes = v;
+ }
- public void setClassIndex(int classIndex) {
- this.classIndex = classIndex;
- }
-
- public Attribute classAttribute() {
- return this.attribute(this.classIndex());
- }
+ public InstanceInformation() {
+ this.relationName = null;
+ this.attributes = null;
+ }
- public int numAttributes() {
- return this.attributes.size();
- }
+ // Information Instances
- public Attribute attribute(int w) {
- return this.attributes.get(w);
- }
-
- public int numClasses() {
- return this.attributes.get(this.classIndex()).numValues();
- }
-
- public void deleteAttributeAt(Integer integer) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ public void setRelationName(String string) {
+ this.relationName = string;
+ }
- public void insertAttributeAt(Attribute attribute, int i) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ public String getRelationName() {
+ return this.relationName;
+ }
- public void setAttributes(List<Attribute> v) {
- this.attributes = v;
- }
-
-
+ public int classIndex() {
+ return classIndex;
+ }
+
+ public void setClassIndex(int classIndex) {
+ this.classIndex = classIndex;
+ }
+
+ public Attribute classAttribute() {
+ return this.attribute(this.classIndex());
+ }
+
+ public int numAttributes() {
+ return this.attributes.size();
+ }
+
+ public Attribute attribute(int w) {
+ return this.attributes.get(w);
+ }
+
+ public int numClasses() {
+ return this.attributes.get(this.classIndex()).numValues();
+ }
+
+ public void deleteAttributeAt(Integer integer) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
+
+ public void insertAttributeAt(Attribute attribute, int i) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
+
+ public void setAttributes(List<Attribute> v) {
+ this.attributes = v;
+ }
+
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instances.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instances.java
index 85b52ec..d3be605 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instances.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Instances.java
@@ -37,210 +37,208 @@
*/
public class Instances implements Serializable {
- public static final String ARFF_RELATION = "@relation";
- public static final String ARFF_DATA = "@data";
+ public static final String ARFF_RELATION = "@relation";
+ public static final String ARFF_DATA = "@data";
-
- protected InstanceInformation instanceInformation;
- /**
- * The instances.
- */
- protected List<Instance> instances;
+ protected InstanceInformation instanceInformation;
+ /**
+ * The instances.
+ */
+ protected List<Instance> instances;
- transient protected ArffLoader arff;
-
- protected int classAttribute;
+ transient protected ArffLoader arff;
- public Instances(InstancesHeader modelContext) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ protected int classAttribute;
- public Instances(Instances chunk) {
- this.instanceInformation = chunk.instanceInformation();
- // this.relationName = chunk.relationName;
- // this.attributes = chunk.attributes;
- this.instances = chunk.instances;
- }
+ public Instances(InstancesHeader modelContext) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- public Instances() {
- // this.instanceInformation = chunk.instanceInformation();
- // this.relationName = chunk.relationName;
- // this.attributes = chunk.attributes;
- // this.instances = chunk.instances;
- }
+ public Instances(Instances chunk) {
+ this.instanceInformation = chunk.instanceInformation();
+ // this.relationName = chunk.relationName;
+ // this.attributes = chunk.attributes;
+ this.instances = chunk.instances;
+ }
- public Instances(Reader reader, int size, int classAttribute) {
- this.classAttribute = classAttribute;
- arff = new ArffLoader(reader, 0, classAttribute);
- this.instanceInformation = arff.getStructure();
- this.instances = new ArrayList<>();
- }
+ public Instances() {
+ // this.instanceInformation = chunk.instanceInformation();
+ // this.relationName = chunk.relationName;
+ // this.attributes = chunk.attributes;
+ // this.instances = chunk.instances;
+ }
- public Instances(Instances chunk, int capacity) {
- this(chunk);
- }
+ public Instances(Reader reader, int size, int classAttribute) {
+ this.classAttribute = classAttribute;
+ arff = new ArffLoader(reader, 0, classAttribute);
+ this.instanceInformation = arff.getStructure();
+ this.instances = new ArrayList<>();
+ }
- public Instances(String st, List<Attribute> v, int capacity) {
-
- this.instanceInformation = new InstanceInformation(st, v);
- this.instances = new ArrayList<>();
- }
+ public Instances(Instances chunk, int capacity) {
+ this(chunk);
+ }
- public Instances(Instances chunk, int i, int j) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ public Instances(String st, List<Attribute> v, int capacity) {
- public Instances(StringReader st, int v) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ this.instanceInformation = new InstanceInformation(st, v);
+ this.instances = new ArrayList<>();
+ }
- // Information Instances
- public void setRelationName(String string) {
- this.instanceInformation.setRelationName(string);
- }
+ public Instances(Instances chunk, int i, int j) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- public String getRelationName() {
- return this.instanceInformation.getRelationName();
- }
+ public Instances(StringReader st, int v) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- public int classIndex() {
- return this.instanceInformation.classIndex();
- }
+ // Information Instances
+ public void setRelationName(String string) {
+ this.instanceInformation.setRelationName(string);
+ }
- public void setClassIndex(int classIndex) {
- this.instanceInformation.setClassIndex(classIndex);
- }
+ public String getRelationName() {
+ return this.instanceInformation.getRelationName();
+ }
- public Attribute classAttribute() {
- return this.instanceInformation.classAttribute();
- }
+ public int classIndex() {
+ return this.instanceInformation.classIndex();
+ }
- public int numAttributes() {
- return this.instanceInformation.numAttributes();
- }
+ public void setClassIndex(int classIndex) {
+ this.instanceInformation.setClassIndex(classIndex);
+ }
- public Attribute attribute(int w) {
- return this.instanceInformation.attribute(w);
- }
+ public Attribute classAttribute() {
+ return this.instanceInformation.classAttribute();
+ }
- public int numClasses() {
- return this.instanceInformation.numClasses();
- }
+ public int numAttributes() {
+ return this.instanceInformation.numAttributes();
+ }
- public void deleteAttributeAt(Integer integer) {
- this.instanceInformation.deleteAttributeAt(integer);
- }
+ public Attribute attribute(int w) {
+ return this.instanceInformation.attribute(w);
+ }
- public void insertAttributeAt(Attribute attribute, int i) {
- this.instanceInformation.insertAttributeAt(attribute, i);
- }
+ public int numClasses() {
+ return this.instanceInformation.numClasses();
+ }
- // List of Instances
- public Instance instance(int num) {
- return this.instances.get(num);
- }
+ public void deleteAttributeAt(Integer integer) {
+ this.instanceInformation.deleteAttributeAt(integer);
+ }
- public int numInstances() {
- return this.instances.size();
- }
+ public void insertAttributeAt(Attribute attribute, int i) {
+ this.instanceInformation.insertAttributeAt(attribute, i);
+ }
- public void add(Instance inst) {
- this.instances.add(inst.copy());
- }
+ // List of Instances
+ public Instance instance(int num) {
+ return this.instances.get(num);
+ }
- public void randomize(Random random) {
- for (int j = numInstances() - 1; j > 0; j--) {
- swap(j, random.nextInt(j + 1));
- }
- }
+ public int numInstances() {
+ return this.instances.size();
+ }
- public void stratify(int numFolds) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ public void add(Instance inst) {
+ this.instances.add(inst.copy());
+ }
- public Instances trainCV(int numFolds, int n, Random random) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ public void randomize(Random random) {
+ for (int j = numInstances() - 1; j > 0; j--) {
+ swap(j, random.nextInt(j + 1));
+ }
+ }
- public Instances testCV(int numFolds, int n) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ public void stratify(int numFolds) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- /*
- * public Instances dataset() { throw new
- * UnsupportedOperationException("Not yet implemented"); }
- */
- public double meanOrMode(int j) {
- throw new UnsupportedOperationException("Not yet implemented"); // CobWeb
- }
+ public Instances trainCV(int numFolds, int n, Random random) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- public boolean readInstance(Reader fileReader) {
+ public Instances testCV(int numFolds, int n) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- // ArffReader arff = new ArffReader(reader, this, m_Lines, 1);
- if (arff == null) {
- arff = new ArffLoader(fileReader,0,this.classAttribute);
- }
- Instance inst = arff.readInstance(fileReader);
- if (inst != null) {
- inst.setDataset(this);
- add(inst);
- return true;
- } else {
- return false;
- }
- }
+ /*
+ * public Instances dataset() { throw new
+ * UnsupportedOperationException("Not yet implemented"); }
+ */
+ public double meanOrMode(int j) {
+ throw new UnsupportedOperationException("Not yet implemented"); // CobWeb
+ }
- public void delete() {
- this.instances = new ArrayList<>();
- }
+ public boolean readInstance(Reader fileReader) {
- public void swap(int i, int j) {
- Instance in = instances.get(i);
- instances.set(i, instances.get(j));
- instances.set(j, in);
- }
+ // ArffReader arff = new ArffReader(reader, this, m_Lines, 1);
+ if (arff == null) {
+ arff = new ArffLoader(fileReader, 0, this.classAttribute);
+ }
+ Instance inst = arff.readInstance(fileReader);
+ if (inst != null) {
+ inst.setDataset(this);
+ add(inst);
+ return true;
+ } else {
+ return false;
+ }
+ }
- private InstanceInformation instanceInformation() {
- return this.instanceInformation;
- }
+ public void delete() {
+ this.instances = new ArrayList<>();
+ }
- public Attribute attribute(String name) {
+ public void swap(int i, int j) {
+ Instance in = instances.get(i);
+ instances.set(i, instances.get(j));
+ instances.set(j, in);
+ }
- for (int i = 0; i < numAttributes(); i++) {
- if (attribute(i).name().equals(name)) {
- return attribute(i);
- }
- }
- return null;
- }
+ private InstanceInformation instanceInformation() {
+ return this.instanceInformation;
+ }
+ public Attribute attribute(String name) {
- @Override
- public String toString() {
- StringBuilder text = new StringBuilder();
+ for (int i = 0; i < numAttributes(); i++) {
+ if (attribute(i).name().equals(name)) {
+ return attribute(i);
+ }
+ }
+ return null;
+ }
- for (int i = 0; i < numInstances(); i++) {
- text.append(instance(i).toString());
- if (i < numInstances() - 1) {
- text.append('\n');
- }
- }
- return text.toString();
- }
+ @Override
+ public String toString() {
+ StringBuilder text = new StringBuilder();
- // toString() with header
- public String toStringArff() {
- StringBuilder text = new StringBuilder();
+ for (int i = 0; i < numInstances(); i++) {
+ text.append(instance(i).toString());
+ if (i < numInstances() - 1) {
+ text.append('\n');
+ }
+ }
+ return text.toString();
+ }
- text.append(ARFF_RELATION).append(" ")
- .append(Utils.quote(getRelationName())).append("\n\n");
- for (int i = 0; i < numAttributes(); i++) {
- text.append(attribute(i).toString()).append("\n");
- }
- text.append("\n").append(ARFF_DATA).append("\n");
+ // toString() with header
+ public String toStringArff() {
+ StringBuilder text = new StringBuilder();
- text.append(toString());
- return text.toString();
+ text.append(ARFF_RELATION).append(" ")
+ .append(Utils.quote(getRelationName())).append("\n\n");
+ for (int i = 0; i < numAttributes(); i++) {
+ text.append(attribute(i).toString()).append("\n");
+ }
+ text.append("\n").append(ARFF_DATA).append("\n");
- }
+ text.append(toString());
+ return text.toString();
+
+ }
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstancesHeader.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstancesHeader.java
index 1ffa6e7..095e8d7 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstancesHeader.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/InstancesHeader.java
@@ -20,112 +20,105 @@
* #L%
*/
-
/**
- * Class for storing the header or context of a data stream. It allows to know the number of attributes and classes.
- *
+ * Class for storing the header or context of a data stream. It allows to know
+ * the number of attributes and classes.
+ *
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 7 $
*/
public class InstancesHeader extends Instances {
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- public InstancesHeader(Instances i) {
- super(i, 0);
+ public InstancesHeader(Instances i) {
+ super(i, 0);
+ }
+
+ public InstancesHeader() {
+ super();
+ }
+
+ /*
+ * @Override public boolean add(Instance i) { throw new
+ * UnsupportedOperationException(); }
+ *
+ * @Override public boolean readInstance(Reader r) throws IOException { throw
+ * new UnsupportedOperationException(); }
+ */
+
+ public static String getClassNameString(InstancesHeader context) {
+ if (context == null) {
+ return "[class]";
}
+ return "[class:" + context.classAttribute().name() + "]";
+ }
- public InstancesHeader() {
- super();
+ public static String getClassLabelString(InstancesHeader context,
+ int classLabelIndex) {
+ if ((context == null) || (classLabelIndex >= context.numClasses())) {
+ return "<class " + (classLabelIndex + 1) + ">";
}
-
- /* @Override
- public boolean add(Instance i) {
- throw new UnsupportedOperationException();
+ return "<class " + (classLabelIndex + 1) + ":"
+ + context.classAttribute().value(classLabelIndex) + ">";
+ }
+
+ // is impervious to class index changes - attIndex is true attribute index
+ // regardless of class position
+ public static String getAttributeNameString(InstancesHeader context,
+ int attIndex) {
+ if ((context == null) || (attIndex >= context.numAttributes())) {
+ return "[att " + (attIndex + 1) + "]";
}
+ int instAttIndex = attIndex < context.classIndex() ? attIndex
+ : attIndex + 1;
+ return "[att " + (attIndex + 1) + ":"
+ + context.attribute(instAttIndex).name() + "]";
+ }
- @Override
- public boolean readInstance(Reader r) throws IOException {
- throw new UnsupportedOperationException();
- }*/
+ // is impervious to class index changes - attIndex is true attribute index
+ // regardless of class position
+ public static String getNominalValueString(InstancesHeader context,
+ int attIndex, int valIndex) {
+ if (context != null) {
+ int instAttIndex = attIndex < context.classIndex() ? attIndex
+ : attIndex + 1;
+ if ((instAttIndex < context.numAttributes())
+ && (valIndex < context.attribute(instAttIndex).numValues())) {
+ return "{val " + (valIndex + 1) + ":"
+ + context.attribute(instAttIndex).value(valIndex) + "}";
+ }
+ }
+ return "{val " + (valIndex + 1) + "}";
+ }
- public static String getClassNameString(InstancesHeader context) {
- if (context == null) {
- return "[class]";
+ // is impervious to class index changes - attIndex is true attribute index
+ // regardless of class position
+ public static String getNumericValueString(InstancesHeader context,
+ int attIndex, double value) {
+ if (context != null) {
+ int instAttIndex = attIndex < context.classIndex() ? attIndex
+ : attIndex + 1;
+ if (instAttIndex < context.numAttributes()) {
+ if (context.attribute(instAttIndex).isDate()) {
+ return context.attribute(instAttIndex).formatDate(value);
}
- return "[class:" + context.classAttribute().name() + "]";
+ }
}
+ return Double.toString(value);
+ }
- public static String getClassLabelString(InstancesHeader context,
- int classLabelIndex) {
- if ((context == null) || (classLabelIndex >= context.numClasses())) {
- return "<class " + (classLabelIndex + 1) + ">";
- }
- return "<class " + (classLabelIndex + 1) + ":"
- + context.classAttribute().value(classLabelIndex) + ">";
- }
-
- // is impervious to class index changes - attIndex is true attribute index
- // regardless of class position
- public static String getAttributeNameString(InstancesHeader context,
- int attIndex) {
- if ((context == null) || (attIndex >= context.numAttributes())) {
- return "[att " + (attIndex + 1) + "]";
- }
- int instAttIndex = attIndex < context.classIndex() ? attIndex
- : attIndex + 1;
- return "[att " + (attIndex + 1) + ":"
- + context.attribute(instAttIndex).name() + "]";
- }
-
- // is impervious to class index changes - attIndex is true attribute index
- // regardless of class position
- public static String getNominalValueString(InstancesHeader context,
- int attIndex, int valIndex) {
- if (context != null) {
- int instAttIndex = attIndex < context.classIndex() ? attIndex
- : attIndex + 1;
- if ((instAttIndex < context.numAttributes())
- && (valIndex < context.attribute(instAttIndex).numValues())) {
- return "{val " + (valIndex + 1) + ":"
- + context.attribute(instAttIndex).value(valIndex) + "}";
- }
- }
- return "{val " + (valIndex + 1) + "}";
- }
-
- // is impervious to class index changes - attIndex is true attribute index
- // regardless of class position
- public static String getNumericValueString(InstancesHeader context,
- int attIndex, double value) {
- if (context != null) {
- int instAttIndex = attIndex < context.classIndex() ? attIndex
- : attIndex + 1;
- if (instAttIndex < context.numAttributes()) {
- if (context.attribute(instAttIndex).isDate()) {
- return context.attribute(instAttIndex).formatDate(value);
- }
- }
- }
- return Double.toString(value);
- }
-
-
- //add autom.
- /* public int classIndex() {
- throw new UnsupportedOperationException("Not yet implemented");
- }
-
- public int numAttributes() {
- throw new UnsupportedOperationException("Not yet implemented");
- }
-
- @Override
- public Attribute attribute(int nPos) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
-
- public int numClasses() {
- return 0;
- }*/
+ // add autom.
+ /*
+ * public int classIndex() { throw new
+ * UnsupportedOperationException("Not yet implemented"); }
+ *
+ * public int numAttributes() { throw new
+ * UnsupportedOperationException("Not yet implemented"); }
+ *
+ * @Override public Attribute attribute(int nPos) { throw new
+ * UnsupportedOperationException("Not yet implemented"); }
+ *
+ * public int numClasses() { return 0; }
+ */
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleClassInstanceData.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleClassInstanceData.java
index 878c338..b3007b0 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleClassInstanceData.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleClassInstanceData.java
@@ -25,62 +25,62 @@
*/
/**
- *
+ *
* @author abifet
*/
public class SingleClassInstanceData implements InstanceData {
- protected double classValue;
-
- @Override
- public int numAttributes() {
- return 1;
- }
+ protected double classValue;
- @Override
- public double value(int instAttIndex) {
- return classValue;
- }
+ @Override
+ public int numAttributes() {
+ return 1;
+ }
- @Override
- public boolean isMissing(int indexAttribute) {
- return Double.isNaN(this.value(indexAttribute));
- }
+ @Override
+ public double value(int instAttIndex) {
+ return classValue;
+ }
- @Override
- public int numValues() {
- return 1;
- }
+ @Override
+ public boolean isMissing(int indexAttribute) {
+ return Double.isNaN(this.value(indexAttribute));
+ }
- @Override
- public int index(int i) {
- return 0;
- }
+ @Override
+ public int numValues() {
+ return 1;
+ }
- @Override
- public double valueSparse(int i) {
- return value(i);
- }
+ @Override
+ public int index(int i) {
+ return 0;
+ }
- @Override
- public boolean isMissingSparse(int indexAttribute) {
- return Double.isNaN(this.value(indexAttribute));
- }
+ @Override
+ public double valueSparse(int i) {
+ return value(i);
+ }
- /*@Override
- public double value(Attribute attribute) {
- return this.classValue;
- }*/
+ @Override
+ public boolean isMissingSparse(int indexAttribute) {
+ return Double.isNaN(this.value(indexAttribute));
+ }
- @Override
- public double[] toDoubleArray() {
- double[] array = {this.classValue};
- return array;
- }
+ /*
+ * @Override public double value(Attribute attribute) { return
+ * this.classValue; }
+ */
- @Override
- public void setValue(int m_numAttributes, double d) {
- this.classValue = d;
- }
-
+ @Override
+ public double[] toDoubleArray() {
+ double[] array = { this.classValue };
+ return array;
+ }
+
+ @Override
+ public void setValue(int m_numAttributes, double d) {
+ this.classValue = d;
+ }
+
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleLabelInstance.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleLabelInstance.java
index 81b2818..0cf2bb2 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleLabelInstance.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SingleLabelInstance.java
@@ -32,230 +32,229 @@
public class SingleLabelInstance implements Instance {
- protected double weight;
+ protected double weight;
- protected InstanceData instanceData;
+ protected InstanceData instanceData;
- protected InstanceData classData;
+ protected InstanceData classData;
- // Fast implementation without using Objects
- // protected double[] attributeValues;
- // protected double classValue;
+ // Fast implementation without using Objects
+ // protected double[] attributeValues;
+ // protected double classValue;
- protected InstancesHeader instanceInformation;
+ protected InstancesHeader instanceInformation;
- public SingleLabelInstance() {
- // necessary for kryo serializer
- }
+ public SingleLabelInstance() {
+ // necessary for kryo serializer
+ }
- public SingleLabelInstance(SingleLabelInstance inst) {
- this.weight = inst.weight;
- this.instanceData = inst.instanceData; // copy
- this.classData = inst.classData; // copy
- // this.classValue = inst.classValue;
- // this.attributeValues = inst.attributeValues;
- this.instanceInformation = inst.instanceInformation;
- }
+ public SingleLabelInstance(SingleLabelInstance inst) {
+ this.weight = inst.weight;
+ this.instanceData = inst.instanceData; // copy
+ this.classData = inst.classData; // copy
+ // this.classValue = inst.classValue;
+ // this.attributeValues = inst.attributeValues;
+ this.instanceInformation = inst.instanceInformation;
+ }
- // Dense
- public SingleLabelInstance(double weight, double[] res) {
- this.weight = weight;
- this.instanceData = new DenseInstanceData(res);
- //this.attributeValues = res;
- this.classData = new SingleClassInstanceData();
- // this.classValue = Double.NaN;
-
-
- }
+ // Dense
+ public SingleLabelInstance(double weight, double[] res) {
+ this.weight = weight;
+ this.instanceData = new DenseInstanceData(res);
+ // this.attributeValues = res;
+ this.classData = new SingleClassInstanceData();
+ // this.classValue = Double.NaN;
- // Sparse
- public SingleLabelInstance(double weight, double[] attributeValues,
- int[] indexValues, int numberAttributes) {
- this.weight = weight;
- this.instanceData = new SparseInstanceData(attributeValues,
- indexValues, numberAttributes); // ???
- this.classData = new SingleClassInstanceData();
- // this.classValue = Double.NaN;
- //this.instanceInformation = new InstancesHeader();
-
- }
+ }
- public SingleLabelInstance(double weight, InstanceData instanceData) {
- this.weight = weight;
- this.instanceData = instanceData; // ???
- // this.classValue = Double.NaN;
- this.classData = new SingleClassInstanceData();
- //this.instanceInformation = new InstancesHeader();
- }
+ // Sparse
+ public SingleLabelInstance(double weight, double[] attributeValues,
+ int[] indexValues, int numberAttributes) {
+ this.weight = weight;
+ this.instanceData = new SparseInstanceData(attributeValues,
+ indexValues, numberAttributes); // ???
+ this.classData = new SingleClassInstanceData();
+ // this.classValue = Double.NaN;
+ // this.instanceInformation = new InstancesHeader();
- public SingleLabelInstance(int numAttributes) {
- this.instanceData = new DenseInstanceData(new double[numAttributes]);
- // m_AttValues = new double[numAttributes];
- /*
- * for (int i = 0; i < m_AttValues.length; i++) { m_AttValues[i] =
- * Utils.missingValue(); }
- */
- this.weight = 1;
- this.classData = new SingleClassInstanceData();
- this.instanceInformation = new InstancesHeader();
- }
+ }
- @Override
- public double weight() {
- return weight;
- }
+ public SingleLabelInstance(double weight, InstanceData instanceData) {
+ this.weight = weight;
+ this.instanceData = instanceData; // ???
+ // this.classValue = Double.NaN;
+ this.classData = new SingleClassInstanceData();
+ // this.instanceInformation = new InstancesHeader();
+ }
- @Override
- public void setWeight(double weight) {
- this.weight = weight;
- }
+ public SingleLabelInstance(int numAttributes) {
+ this.instanceData = new DenseInstanceData(new double[numAttributes]);
+ // m_AttValues = new double[numAttributes];
+ /*
+ * for (int i = 0; i < m_AttValues.length; i++) { m_AttValues[i] =
+ * Utils.missingValue(); }
+ */
+ this.weight = 1;
+ this.classData = new SingleClassInstanceData();
+ this.instanceInformation = new InstancesHeader();
+ }
- @Override
- public Attribute attribute(int instAttIndex) {
- return this.instanceInformation.attribute(instAttIndex);
- }
+ @Override
+ public double weight() {
+ return weight;
+ }
- @Override
- public void deleteAttributeAt(int i) {
- // throw new UnsupportedOperationException("Not yet implemented");
- }
+ @Override
+ public void setWeight(double weight) {
+ this.weight = weight;
+ }
- @Override
- public void insertAttributeAt(int i) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ @Override
+ public Attribute attribute(int instAttIndex) {
+ return this.instanceInformation.attribute(instAttIndex);
+ }
- @Override
- public int numAttributes() {
- return this.instanceInformation.numAttributes();
- }
+ @Override
+ public void deleteAttributeAt(int i) {
+ // throw new UnsupportedOperationException("Not yet implemented");
+ }
- @Override
- public double value(int instAttIndex) {
- return // attributeValues[instAttIndex]; //
- this.instanceData.value(instAttIndex);
- }
+ @Override
+ public void insertAttributeAt(int i) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- @Override
- public boolean isMissing(int instAttIndex) {
- return // Double.isNaN(value(instAttIndex)); //
- this.instanceData.isMissing(instAttIndex);
- }
+ @Override
+ public int numAttributes() {
+ return this.instanceInformation.numAttributes();
+ }
- @Override
- public int numValues() {
- return // this.attributeValues.length; //
- this.instanceData.numValues();
- }
+ @Override
+ public double value(int instAttIndex) {
+ return // attributeValues[instAttIndex]; //
+ this.instanceData.value(instAttIndex);
+ }
- @Override
- public int index(int i) {
- return // i; //
- this.instanceData.index(i);
- }
+ @Override
+ public boolean isMissing(int instAttIndex) {
+ return // Double.isNaN(value(instAttIndex)); //
+ this.instanceData.isMissing(instAttIndex);
+ }
- @Override
- public double valueSparse(int i) {
- return this.instanceData.valueSparse(i);
- }
+ @Override
+ public int numValues() {
+ return // this.attributeValues.length; //
+ this.instanceData.numValues();
+ }
- @Override
- public boolean isMissingSparse(int p) {
- return this.instanceData.isMissingSparse(p);
- }
+ @Override
+ public int index(int i) {
+ return // i; //
+ this.instanceData.index(i);
+ }
- @Override
- public double value(Attribute attribute) {
- // throw new UnsupportedOperationException("Not yet implemented");
- // //Predicates.java
- return value(attribute.index());
+ @Override
+ public double valueSparse(int i) {
+ return this.instanceData.valueSparse(i);
+ }
- }
+ @Override
+ public boolean isMissingSparse(int p) {
+ return this.instanceData.isMissingSparse(p);
+ }
- @Override
- public String stringValue(int i) {
- throw new UnsupportedOperationException("Not yet implemented");
- }
+ @Override
+ public double value(Attribute attribute) {
+ // throw new UnsupportedOperationException("Not yet implemented");
+ // //Predicates.java
+ return value(attribute.index());
- @Override
- public double[] toDoubleArray() {
- return // this.attributeValues; //
- this.instanceData.toDoubleArray();
- }
+ }
- @Override
- public void setValue(int numAttribute, double d) {
- this.instanceData.setValue(numAttribute, d);
- // this.attributeValues[numAttribute] = d;
- }
+ @Override
+ public String stringValue(int i) {
+ throw new UnsupportedOperationException("Not yet implemented");
+ }
- @Override
- public double classValue() {
- return this.classData.value(0);
- // return classValue;
- }
+ @Override
+ public double[] toDoubleArray() {
+ return // this.attributeValues; //
+ this.instanceData.toDoubleArray();
+ }
- @Override
- public int classIndex() {
- return instanceInformation.classIndex();
- }
+ @Override
+ public void setValue(int numAttribute, double d) {
+ this.instanceData.setValue(numAttribute, d);
+ // this.attributeValues[numAttribute] = d;
+ }
- @Override
- public int numClasses() {
- return this.instanceInformation.numClasses();
- }
+ @Override
+ public double classValue() {
+ return this.classData.value(0);
+ // return classValue;
+ }
- @Override
- public boolean classIsMissing() {
- return // Double.isNaN(this.classValue);//
- this.classData.isMissing(0);
- }
+ @Override
+ public int classIndex() {
+ return instanceInformation.classIndex();
+ }
- @Override
- public Attribute classAttribute() {
- return this.instanceInformation.attribute(0);
- }
+ @Override
+ public int numClasses() {
+ return this.instanceInformation.numClasses();
+ }
- @Override
- public void setClassValue(double d) {
- this.classData.setValue(0, d);
- // this.classValue = d;
- }
+ @Override
+ public boolean classIsMissing() {
+ return // Double.isNaN(this.classValue);//
+ this.classData.isMissing(0);
+ }
- @Override
- public Instance copy() {
- SingleLabelInstance inst = new SingleLabelInstance(this);
- return inst;
- }
+ @Override
+ public Attribute classAttribute() {
+ return this.instanceInformation.attribute(0);
+ }
- @Override
- public Instances dataset() {
- return this.instanceInformation;
- }
+ @Override
+ public void setClassValue(double d) {
+ this.classData.setValue(0, d);
+ // this.classValue = d;
+ }
- @Override
- public void setDataset(Instances dataset) {
- this.instanceInformation = new InstancesHeader(dataset);
- }
+ @Override
+ public Instance copy() {
+ SingleLabelInstance inst = new SingleLabelInstance(this);
+ return inst;
+ }
- public void addSparseValues(int[] indexValues, double[] attributeValues,
- int numberAttributes) {
- this.instanceData = new SparseInstanceData(attributeValues,
- indexValues, numberAttributes); // ???
- }
+ @Override
+ public Instances dataset() {
+ return this.instanceInformation;
+ }
- @Override
- public String toString() {
- StringBuffer text = new StringBuffer();
+ @Override
+ public void setDataset(Instances dataset) {
+ this.instanceInformation = new InstancesHeader(dataset);
+ }
- for (int i = 0; i < this.numValues() ; i++) {
- if (i > 0)
- text.append(",");
- text.append(this.value(i));
- }
- text.append(",").append(this.weight());
+ public void addSparseValues(int[] indexValues, double[] attributeValues,
+ int numberAttributes) {
+ this.instanceData = new SparseInstanceData(attributeValues,
+ indexValues, numberAttributes); // ???
+ }
- return text.toString();
- }
+ @Override
+ public String toString() {
+ StringBuffer text = new StringBuffer();
+
+ for (int i = 0; i < this.numValues(); i++) {
+ if (i > 0)
+ text.append(",");
+ text.append(this.value(i));
+ }
+ text.append(",").append(this.weight());
+
+ return text.toString();
+ }
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstance.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstance.java
index 66d0715..e55dee5 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstance.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstance.java
@@ -25,25 +25,26 @@
*/
/**
- *
+ *
* @author abifet
*/
-public class SparseInstance extends SingleLabelInstance{
-
- public SparseInstance(double d, double[] res) {
- super(d,res);
- }
- public SparseInstance(SingleLabelInstance inst) {
- super(inst);
- }
+public class SparseInstance extends SingleLabelInstance {
- public SparseInstance(double numberAttributes) {
- //super(1, new double[(int) numberAttributes-1]);
- super(1,null,null,(int) numberAttributes);
- }
-
- public SparseInstance(double weight, double[] attributeValues, int[] indexValues, int numberAttributes) {
- super(weight,attributeValues,indexValues,numberAttributes);
- }
-
+ public SparseInstance(double d, double[] res) {
+ super(d, res);
+ }
+
+ public SparseInstance(SingleLabelInstance inst) {
+ super(inst);
+ }
+
+ public SparseInstance(double numberAttributes) {
+ // super(1, new double[(int) numberAttributes-1]);
+ super(1, null, null, (int) numberAttributes);
+ }
+
+ public SparseInstance(double weight, double[] attributeValues, int[] indexValues, int numberAttributes) {
+ super(weight, attributeValues, indexValues, numberAttributes);
+ }
+
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstanceData.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstanceData.java
index e917844..1db95d0 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstanceData.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/SparseInstanceData.java
@@ -25,118 +25,118 @@
*/
/**
- *
+ *
* @author abifet
*/
-public class SparseInstanceData implements InstanceData{
-
- public SparseInstanceData(double[] attributeValues, int[] indexValues, int numberAttributes) {
- this.attributeValues = attributeValues;
- this.indexValues = indexValues;
- this.numberAttributes = numberAttributes;
- }
-
- public SparseInstanceData(int length) {
- this.attributeValues = new double[length];
- this.indexValues = new int[length];
- }
-
-
- protected double[] attributeValues;
+public class SparseInstanceData implements InstanceData {
- public double[] getAttributeValues() {
- return attributeValues;
- }
+ public SparseInstanceData(double[] attributeValues, int[] indexValues, int numberAttributes) {
+ this.attributeValues = attributeValues;
+ this.indexValues = indexValues;
+ this.numberAttributes = numberAttributes;
+ }
- public void setAttributeValues(double[] attributeValues) {
- this.attributeValues = attributeValues;
- }
+ public SparseInstanceData(int length) {
+ this.attributeValues = new double[length];
+ this.indexValues = new int[length];
+ }
- public int[] getIndexValues() {
- return indexValues;
- }
+ protected double[] attributeValues;
- public void setIndexValues(int[] indexValues) {
- this.indexValues = indexValues;
- }
+ public double[] getAttributeValues() {
+ return attributeValues;
+ }
- public int getNumberAttributes() {
- return numberAttributes;
- }
+ public void setAttributeValues(double[] attributeValues) {
+ this.attributeValues = attributeValues;
+ }
- public void setNumberAttributes(int numberAttributes) {
- this.numberAttributes = numberAttributes;
- }
- protected int[] indexValues;
- protected int numberAttributes;
+ public int[] getIndexValues() {
+ return indexValues;
+ }
- @Override
- public int numAttributes() {
- return this.numberAttributes;
- }
+ public void setIndexValues(int[] indexValues) {
+ this.indexValues = indexValues;
+ }
- @Override
- public double value(int indexAttribute) {
- int location = locateIndex(indexAttribute);
- //return location == -1 ? 0 : this.attributeValues[location];
- // int index = locateIndex(attIndex);
+ public int getNumberAttributes() {
+ return numberAttributes;
+ }
+
+ public void setNumberAttributes(int numberAttributes) {
+ this.numberAttributes = numberAttributes;
+ }
+
+ protected int[] indexValues;
+ protected int numberAttributes;
+
+ @Override
+ public int numAttributes() {
+ return this.numberAttributes;
+ }
+
+ @Override
+ public double value(int indexAttribute) {
+ int location = locateIndex(indexAttribute);
+ // return location == -1 ? 0 : this.attributeValues[location];
+ // int index = locateIndex(attIndex);
if ((location >= 0) && (indexValues[location] == indexAttribute)) {
return attributeValues[location];
} else {
return 0.0;
}
- }
+ }
- @Override
- public boolean isMissing(int indexAttribute) {
- return Double.isNaN(this.value(indexAttribute));
- }
+ @Override
+ public boolean isMissing(int indexAttribute) {
+ return Double.isNaN(this.value(indexAttribute));
+ }
- @Override
- public int numValues() {
- return this.attributeValues.length;
- }
+ @Override
+ public int numValues() {
+ return this.attributeValues.length;
+ }
- @Override
- public int index(int indexAttribute) {
- return this.indexValues[indexAttribute];
- }
+ @Override
+ public int index(int indexAttribute) {
+ return this.indexValues[indexAttribute];
+ }
- @Override
- public double valueSparse(int indexAttribute) {
- return this.attributeValues[indexAttribute];
- }
+ @Override
+ public double valueSparse(int indexAttribute) {
+ return this.attributeValues[indexAttribute];
+ }
- @Override
- public boolean isMissingSparse(int indexAttribute) {
- return Double.isNaN(this.valueSparse(indexAttribute));
- }
+ @Override
+ public boolean isMissingSparse(int indexAttribute) {
+ return Double.isNaN(this.valueSparse(indexAttribute));
+ }
- /*@Override
- public double value(Attribute attribute) {
- return value(attribute.index());
- }*/
+ /*
+ * @Override public double value(Attribute attribute) { return
+ * value(attribute.index()); }
+ */
- @Override
- public double[] toDoubleArray() {
- double[] array = new double[numAttributes()];
- for (int i=0; i<numValues() ; i++) {
- array[index(i)] = valueSparse(i);
- }
- return array;
+ @Override
+ public double[] toDoubleArray() {
+ double[] array = new double[numAttributes()];
+ for (int i = 0; i < numValues(); i++) {
+ array[index(i)] = valueSparse(i);
}
+ return array;
+ }
- @Override
- public void setValue(int attributeIndex, double d) {
- int index = locateIndex(attributeIndex);
- if (index(index) == attributeIndex) {
- this.attributeValues[index] = d;
- } else {
- // We need to add the value
- }
+ @Override
+ public void setValue(int attributeIndex, double d) {
+ int index = locateIndex(attributeIndex);
+ if (index(index) == attributeIndex) {
+ this.attributeValues[index] = d;
+ } else {
+ // We need to add the value
}
-
- /**
+ }
+
+ /**
* Locates the greatest index that is not greater than the given index.
*
* @return the internal index of the attribute index. Returns -1 if no index
@@ -168,5 +168,5 @@
return min - 1;
}
}
-
+
}
diff --git a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Utils.java b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Utils.java
index f3dc1b9..dd9df6d 100644
--- a/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Utils.java
+++ b/samoa-instances/src/main/java/com/yahoo/labs/samoa/instances/Utils.java
@@ -21,68 +21,71 @@
*/
public class Utils {
- public static int maxIndex(double[] doubles) {
+ public static int maxIndex(double[] doubles) {
- double maximum = 0;
- int maxIndex = 0;
+ double maximum = 0;
+ int maxIndex = 0;
- for (int i = 0; i < doubles.length; i++) {
- if ((i == 0) || (doubles[i] > maximum)) {
- maxIndex = i;
- maximum = doubles[i];
- }
- }
-
- return maxIndex;
+ for (int i = 0; i < doubles.length; i++) {
+ if ((i == 0) || (doubles[i] > maximum)) {
+ maxIndex = i;
+ maximum = doubles[i];
+ }
}
- public static String quote(String string) {
- boolean quote = false;
+ return maxIndex;
+ }
- // backquote the following characters
- if ((string.indexOf('\n') != -1) || (string.indexOf('\r') != -1) || (string.indexOf('\'') != -1) || (string.indexOf('"') != -1)
- || (string.indexOf('\\') != -1) || (string.indexOf('\t') != -1) || (string.indexOf('%') != -1) || (string.indexOf('\u001E') != -1)) {
- string = backQuoteChars(string);
- quote = true;
- }
+ public static String quote(String string) {
+ boolean quote = false;
- // Enclose the string in 's if the string contains a recently added
- // backquote or contains one of the following characters.
- if ((quote == true) || (string.indexOf('{') != -1) || (string.indexOf('}') != -1) || (string.indexOf(',') != -1) || (string.equals("?"))
- || (string.indexOf(' ') != -1) || (string.equals(""))) {
- string = ("'".concat(string)).concat("'");
- }
-
- return string;
+ // backquote the following characters
+ if ((string.indexOf('\n') != -1) || (string.indexOf('\r') != -1) || (string.indexOf('\'') != -1)
+ || (string.indexOf('"') != -1)
+ || (string.indexOf('\\') != -1) || (string.indexOf('\t') != -1) || (string.indexOf('%') != -1)
+ || (string.indexOf('\u001E') != -1)) {
+ string = backQuoteChars(string);
+ quote = true;
}
- public static String backQuoteChars(String string) {
-
- int index;
- StringBuffer newStringBuffer;
-
- // replace each of the following characters with the backquoted version
- char charsFind[] = { '\\', '\'', '\t', '\n', '\r', '"', '%', '\u001E' };
- String charsReplace[] = { "\\\\", "\\'", "\\t", "\\n", "\\r", "\\\"", "\\%", "\\u001E" };
- for (int i = 0; i < charsFind.length; i++) {
- if (string.indexOf(charsFind[i]) != -1) {
- newStringBuffer = new StringBuffer();
- while ((index = string.indexOf(charsFind[i])) != -1) {
- if (index > 0) {
- newStringBuffer.append(string.substring(0, index));
- }
- newStringBuffer.append(charsReplace[i]);
- if ((index + 1) < string.length()) {
- string = string.substring(index + 1);
- } else {
- string = "";
- }
- }
- newStringBuffer.append(string);
- string = newStringBuffer.toString();
- }
- }
-
- return string;
+ // Enclose the string in 's if the string contains a recently added
+ // backquote or contains one of the following characters.
+ if ((quote == true) || (string.indexOf('{') != -1) || (string.indexOf('}') != -1) || (string.indexOf(',') != -1)
+ || (string.equals("?"))
+ || (string.indexOf(' ') != -1) || (string.equals(""))) {
+ string = ("'".concat(string)).concat("'");
}
+
+ return string;
+ }
+
+ public static String backQuoteChars(String string) {
+
+ int index;
+ StringBuffer newStringBuffer;
+
+ // replace each of the following characters with the backquoted version
+ char charsFind[] = { '\\', '\'', '\t', '\n', '\r', '"', '%', '\u001E' };
+ String charsReplace[] = { "\\\\", "\\'", "\\t", "\\n", "\\r", "\\\"", "\\%", "\\u001E" };
+ for (int i = 0; i < charsFind.length; i++) {
+ if (string.indexOf(charsFind[i]) != -1) {
+ newStringBuffer = new StringBuffer();
+ while ((index = string.indexOf(charsFind[i])) != -1) {
+ if (index > 0) {
+ newStringBuffer.append(string.substring(0, index));
+ }
+ newStringBuffer.append(charsReplace[i]);
+ if ((index + 1) < string.length()) {
+ string = string.substring(index + 1);
+ } else {
+ string = "";
+ }
+ }
+ newStringBuffer.append(string);
+ string = newStringBuffer.toString();
+ }
+ }
+
+ return string;
+ }
}
diff --git a/samoa-local/src/main/java/com/yahoo/labs/samoa/LocalDoTask.java b/samoa-local/src/main/java/com/yahoo/labs/samoa/LocalDoTask.java
index 05ee1e1..27ca7a1 100644
--- a/samoa-local/src/main/java/com/yahoo/labs/samoa/LocalDoTask.java
+++ b/samoa-local/src/main/java/com/yahoo/labs/samoa/LocalDoTask.java
@@ -36,54 +36,55 @@
*/
public class LocalDoTask {
- // TODO: clean up this class for helping ML Developer in SAMOA
- // TODO: clean up code from storm-impl
-
- // It seems that the 3 extra options are not used.
- // Probably should remove them
- private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr.";
- private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. Normally it is sent to stdout.";
- private static final String STATUS_UPDATE_FREQ_MSG = "Wait time in milliseconds between status updates.";
- private static final Logger logger = LoggerFactory.getLogger(LocalDoTask.class);
+ // TODO: clean up this class for helping ML Developer in SAMOA
+ // TODO: clean up code from storm-impl
- /**
- * The main method.
- *
- * @param args
- * the arguments
- */
- public static void main(String[] args) {
+ // It seems that the 3 extra options are not used.
+ // Probably should remove them
+ private static final String SUPPRESS_STATUS_OUT_MSG = "Suppress the task status output. Normally it is sent to stderr.";
+ private static final String SUPPRESS_RESULT_OUT_MSG = "Suppress the task result output. Normally it is sent to stdout.";
+ private static final String STATUS_UPDATE_FREQ_MSG = "Wait time in milliseconds between status updates.";
+ private static final Logger logger = LoggerFactory.getLogger(LocalDoTask.class);
- // ArrayList<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
+ /**
+ * The main method.
+ *
+ * @param args
+ * the arguments
+ */
+ public static void main(String[] args) {
- // args = tmpArgs.toArray(new String[0]);
+ // ArrayList<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
- FlagOption suppressStatusOutOpt = new FlagOption("suppressStatusOut", 'S', SUPPRESS_STATUS_OUT_MSG);
+ // args = tmpArgs.toArray(new String[0]);
- FlagOption suppressResultOutOpt = new FlagOption("suppressResultOut", 'R', SUPPRESS_RESULT_OUT_MSG);
+ FlagOption suppressStatusOutOpt = new FlagOption("suppressStatusOut", 'S', SUPPRESS_STATUS_OUT_MSG);
- IntOption statusUpdateFreqOpt = new IntOption("statusUpdateFrequency", 'F', STATUS_UPDATE_FREQ_MSG, 1000, 0, Integer.MAX_VALUE);
+ FlagOption suppressResultOutOpt = new FlagOption("suppressResultOut", 'R', SUPPRESS_RESULT_OUT_MSG);
- Option[] extraOptions = new Option[] { suppressStatusOutOpt, suppressResultOutOpt, statusUpdateFreqOpt };
+ IntOption statusUpdateFreqOpt = new IntOption("statusUpdateFrequency", 'F', STATUS_UPDATE_FREQ_MSG, 1000, 0,
+ Integer.MAX_VALUE);
- StringBuilder cliString = new StringBuilder();
- for (String arg : args) {
- cliString.append(" ").append(arg);
- }
- logger.debug("Command line string = {}", cliString.toString());
- System.out.println("Command line string = " + cliString.toString());
+ Option[] extraOptions = new Option[] { suppressStatusOutOpt, suppressResultOutOpt, statusUpdateFreqOpt };
- Task task;
- try {
- task = ClassOption.cliStringToObject(cliString.toString(), Task.class, extraOptions);
- logger.info("Successfully instantiating {}", task.getClass().getCanonicalName());
- } catch (Exception e) {
- logger.error("Fail to initialize the task", e);
- System.out.println("Fail to initialize the task" + e);
- return;
- }
- task.setFactory(new SimpleComponentFactory());
- task.init();
- SimpleEngine.submitTopology(task.getTopology());
+ StringBuilder cliString = new StringBuilder();
+ for (String arg : args) {
+ cliString.append(" ").append(arg);
}
+ logger.debug("Command line string = {}", cliString.toString());
+ System.out.println("Command line string = " + cliString.toString());
+
+ Task task;
+ try {
+ task = ClassOption.cliStringToObject(cliString.toString(), Task.class, extraOptions);
+ logger.info("Successfully instantiating {}", task.getClass().getCanonicalName());
+ } catch (Exception e) {
+ logger.error("Fail to initialize the task", e);
+ System.out.println("Fail to initialize the task" + e);
+ return;
+ }
+ task.setFactory(new SimpleComponentFactory());
+ task.init();
+ SimpleEngine.submitTopology(task.getTopology());
+ }
}
diff --git a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactory.java b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactory.java
index b289dbe..0c2f301 100644
--- a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactory.java
+++ b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactory.java
@@ -31,23 +31,23 @@
public class SimpleComponentFactory implements ComponentFactory {
- public ProcessingItem createPi(Processor processor, int paralellism) {
- return new SimpleProcessingItem(processor, paralellism);
- }
+ public ProcessingItem createPi(Processor processor, int paralellism) {
+ return new SimpleProcessingItem(processor, paralellism);
+ }
- public ProcessingItem createPi(Processor processor) {
- return this.createPi(processor, 1);
- }
+ public ProcessingItem createPi(Processor processor) {
+ return this.createPi(processor, 1);
+ }
- public EntranceProcessingItem createEntrancePi(EntranceProcessor processor) {
- return new SimpleEntranceProcessingItem(processor);
- }
+ public EntranceProcessingItem createEntrancePi(EntranceProcessor processor) {
+ return new SimpleEntranceProcessingItem(processor);
+ }
- public Stream createStream(IProcessingItem sourcePi) {
- return new SimpleStream(sourcePi);
- }
+ public Stream createStream(IProcessingItem sourcePi) {
+ return new SimpleStream(sourcePi);
+ }
- public Topology createTopology(String topoName) {
- return new SimpleTopology(topoName);
- }
+ public Topology createTopology(String topoName) {
+ return new SimpleTopology(topoName);
+ }
}
\ No newline at end of file
diff --git a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEngine.java b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEngine.java
index 9d131e1..5ca5837 100644
--- a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEngine.java
+++ b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEngine.java
@@ -28,10 +28,10 @@
public class SimpleEngine {
- public static void submitTopology(Topology topology) {
- SimpleTopology simpleTopology = (SimpleTopology) topology;
- simpleTopology.run();
- // runs until completion
- }
+ public static void submitTopology(Topology topology) {
+ SimpleTopology simpleTopology = (SimpleTopology) topology;
+ simpleTopology.run();
+ // runs until completion
+ }
}
diff --git a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItem.java b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItem.java
index 4652ebb..c9cc601 100644
--- a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItem.java
+++ b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItem.java
@@ -24,10 +24,10 @@
import com.yahoo.labs.samoa.topology.LocalEntranceProcessingItem;
class SimpleEntranceProcessingItem extends LocalEntranceProcessingItem {
- public SimpleEntranceProcessingItem(EntranceProcessor processor) {
- super(processor);
- }
-
- // The default waiting time when there is no available events is 100ms
- // Override waitForNewEvents() to change it
+ public SimpleEntranceProcessingItem(EntranceProcessor processor) {
+ super(processor);
+ }
+
+ // The default waiting time when there is no available events is 100ms
+ // Override waitForNewEvents() to change it
}
diff --git a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItem.java b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItem.java
index e3cc765..77361b1 100644
--- a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItem.java
+++ b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItem.java
@@ -34,54 +34,54 @@
import com.yahoo.labs.samoa.utils.StreamDestination;
/**
- *
+ *
* @author abifet
*/
class SimpleProcessingItem extends AbstractProcessingItem {
- private IProcessingItem[] arrayProcessingItem;
+ private IProcessingItem[] arrayProcessingItem;
- SimpleProcessingItem(Processor processor) {
- super(processor);
- }
-
- SimpleProcessingItem(Processor processor, int parallelism) {
- super(processor);
- this.setParallelism(parallelism);
- }
-
- public IProcessingItem getProcessingItem(int i) {
- return arrayProcessingItem[i];
- }
-
- @Override
- protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
- StreamDestination destination = new StreamDestination(this, this.getParallelism(), scheme);
- ((SimpleStream)inputStream).addDestination(destination);
- return this;
- }
+ SimpleProcessingItem(Processor processor) {
+ super(processor);
+ }
- public SimpleProcessingItem copy() {
- Processor processor = this.getProcessor();
- return new SimpleProcessingItem(processor.newProcessor(processor));
- }
+ SimpleProcessingItem(Processor processor, int parallelism) {
+ super(processor);
+ this.setParallelism(parallelism);
+ }
- public void processEvent(ContentEvent event, int counter) {
-
- int parallelism = this.getParallelism();
- //System.out.println("Process event "+event+" (isLast="+event.isLastEvent()+") with counter="+counter+" while parallelism="+parallelism);
- if (this.arrayProcessingItem == null && parallelism > 0) {
- //Init processing elements, the first time they are needed
- this.arrayProcessingItem = new IProcessingItem[parallelism];
- for (int j = 0; j < parallelism; j++) {
- arrayProcessingItem[j] = this.copy();
- arrayProcessingItem[j].getProcessor().onCreate(j);
- }
- }
- if (this.arrayProcessingItem != null) {
- IProcessingItem pi = this.getProcessingItem(counter);
- Processor p = pi.getProcessor();
- //System.out.println("PI="+pi+", p="+p);
- this.getProcessingItem(counter).getProcessor().process(event);
- }
+ public IProcessingItem getProcessingItem(int i) {
+ return arrayProcessingItem[i];
+ }
+
+ @Override
+ protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
+ StreamDestination destination = new StreamDestination(this, this.getParallelism(), scheme);
+ ((SimpleStream) inputStream).addDestination(destination);
+ return this;
+ }
+
+ public SimpleProcessingItem copy() {
+ Processor processor = this.getProcessor();
+ return new SimpleProcessingItem(processor.newProcessor(processor));
+ }
+
+ public void processEvent(ContentEvent event, int counter) {
+
+ int parallelism = this.getParallelism();
+ // System.out.println("Process event "+event+" (isLast="+event.isLastEvent()+") with counter="+counter+" while parallelism="+parallelism);
+ if (this.arrayProcessingItem == null && parallelism > 0) {
+ // Init processing elements, the first time they are needed
+ this.arrayProcessingItem = new IProcessingItem[parallelism];
+ for (int j = 0; j < parallelism; j++) {
+ arrayProcessingItem[j] = this.copy();
+ arrayProcessingItem[j].getProcessor().onCreate(j);
+ }
}
+ if (this.arrayProcessingItem != null) {
+ IProcessingItem pi = this.getProcessingItem(counter);
+ Processor p = pi.getProcessor();
+ // System.out.println("PI="+pi+", p="+p);
+ this.getProcessingItem(counter).getProcessor().process(event);
+ }
+ }
}
diff --git a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleStream.java b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleStream.java
index 74684a7..09dc555 100644
--- a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleStream.java
+++ b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleStream.java
@@ -38,56 +38,58 @@
* @author abifet
*/
class SimpleStream extends AbstractStream {
- private List<StreamDestination> destinations;
- private int maxCounter;
- private int eventCounter;
+ private List<StreamDestination> destinations;
+ private int maxCounter;
+ private int eventCounter;
- SimpleStream(IProcessingItem sourcePi) {
- super(sourcePi);
- this.destinations = new LinkedList<>();
- this.eventCounter = 0;
- this.maxCounter = 1;
- }
+ SimpleStream(IProcessingItem sourcePi) {
+ super(sourcePi);
+ this.destinations = new LinkedList<>();
+ this.eventCounter = 0;
+ this.maxCounter = 1;
+ }
- private int getNextCounter() {
- if (maxCounter > 0 && eventCounter >= maxCounter) eventCounter = 0;
- this.eventCounter++;
- return this.eventCounter;
- }
+ private int getNextCounter() {
+ if (maxCounter > 0 && eventCounter >= maxCounter)
+ eventCounter = 0;
+ this.eventCounter++;
+ return this.eventCounter;
+ }
- @Override
- public void put(ContentEvent event) {
- this.put(event, this.getNextCounter());
- }
-
- private void put(ContentEvent event, int counter) {
- SimpleProcessingItem pi;
- int parallelism;
- for (StreamDestination destination:destinations) {
- pi = (SimpleProcessingItem) destination.getProcessingItem();
- parallelism = destination.getParallelism();
- switch (destination.getPartitioningScheme()) {
- case SHUFFLE:
- pi.processEvent(event, counter % parallelism);
- break;
- case GROUP_BY_KEY:
- HashCodeBuilder hb = new HashCodeBuilder();
- hb.append(event.getKey());
- int key = hb.build() % parallelism;
- pi.processEvent(event, key);
- break;
- case BROADCAST:
- for (int p = 0; p < parallelism; p++) {
- pi.processEvent(event, p);
- }
- break;
- }
+ @Override
+ public void put(ContentEvent event) {
+ this.put(event, this.getNextCounter());
+ }
+
+ private void put(ContentEvent event, int counter) {
+ SimpleProcessingItem pi;
+ int parallelism;
+ for (StreamDestination destination : destinations) {
+ pi = (SimpleProcessingItem) destination.getProcessingItem();
+ parallelism = destination.getParallelism();
+ switch (destination.getPartitioningScheme()) {
+ case SHUFFLE:
+ pi.processEvent(event, counter % parallelism);
+ break;
+ case GROUP_BY_KEY:
+ HashCodeBuilder hb = new HashCodeBuilder();
+ hb.append(event.getKey());
+ int key = hb.build() % parallelism;
+ pi.processEvent(event, key);
+ break;
+ case BROADCAST:
+ for (int p = 0; p < parallelism; p++) {
+ pi.processEvent(event, p);
}
+ break;
+ }
}
+ }
- public void addDestination(StreamDestination destination) {
- this.destinations.add(destination);
- if (maxCounter <= 0) maxCounter = 1;
- maxCounter *= destination.getParallelism();
- }
+ public void addDestination(StreamDestination destination) {
+ this.destinations.add(destination);
+ if (maxCounter <= 0)
+ maxCounter = 1;
+ maxCounter *= destination.getParallelism();
+ }
}
diff --git a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleTopology.java b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleTopology.java
index 675b4ac..e7fddbd 100644
--- a/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleTopology.java
+++ b/samoa-local/src/main/java/com/yahoo/labs/samoa/topology/impl/SimpleTopology.java
@@ -27,18 +27,21 @@
import com.yahoo.labs.samoa.topology.AbstractTopology;
public class SimpleTopology extends AbstractTopology {
- SimpleTopology(String name) {
- super(name);
- }
+ SimpleTopology(String name) {
+ super(name);
+ }
- public void run() {
- if (this.getEntranceProcessingItems() == null)
- throw new IllegalStateException("You need to set entrance PI before running the topology.");
- if (this.getEntranceProcessingItems().size() != 1)
- throw new IllegalStateException("SimpleTopology supports 1 entrance PI only. Number of entrance PIs is "+this.getEntranceProcessingItems().size());
-
- SimpleEntranceProcessingItem entrancePi = (SimpleEntranceProcessingItem) this.getEntranceProcessingItems().toArray()[0];
- entrancePi.getProcessor().onCreate(0); // id=0 as it is not used in simple mode
- entrancePi.startSendingEvents();
- }
+ public void run() {
+ if (this.getEntranceProcessingItems() == null)
+ throw new IllegalStateException("You need to set entrance PI before running the topology.");
+ if (this.getEntranceProcessingItems().size() != 1)
+ throw new IllegalStateException("SimpleTopology supports 1 entrance PI only. Number of entrance PIs is "
+ + this.getEntranceProcessingItems().size());
+
+ SimpleEntranceProcessingItem entrancePi = (SimpleEntranceProcessingItem) this.getEntranceProcessingItems()
+ .toArray()[0];
+ entrancePi.getProcessor().onCreate(0); // id=0 as it is not used in simple
+ // mode
+ entrancePi.startSendingEvents();
+ }
}
diff --git a/samoa-local/src/test/java/com/yahoo/labs/samoa/AlgosTest.java b/samoa-local/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
index 9bf1c2d..d3e54a8 100644
--- a/samoa-local/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
+++ b/samoa-local/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
@@ -24,64 +24,63 @@
public class AlgosTest {
+ @Test
+ public void testVHTLocal() throws Exception {
- @Test
- public void testVHTLocal() throws Exception {
+ TestParams vhtConfig = new TestParams.Builder()
+ .inputInstances(200_000)
+ .samplingSize(20_000)
+ .evaluationInstances(200_000)
+ .classifiedInstances(200_000)
+ .classificationsCorrect(75f)
+ .kappaStat(0f)
+ .kappaTempStat(0f)
+ .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE)
+ .resultFilePollTimeout(10)
+ .prePollWait(10)
+ .taskClassName(LocalDoTask.class.getName())
+ .build();
+ TestUtils.test(vhtConfig);
- TestParams vhtConfig = new TestParams.Builder()
- .inputInstances(200_000)
- .samplingSize(20_000)
- .evaluationInstances(200_000)
- .classifiedInstances(200_000)
- .classificationsCorrect(75f)
- .kappaStat(0f)
- .kappaTempStat(0f)
- .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE)
- .resultFilePollTimeout(10)
- .prePollWait(10)
- .taskClassName(LocalDoTask.class.getName())
- .build();
- TestUtils.test(vhtConfig);
+ }
- }
+ @Test
+ public void testBaggingLocal() throws Exception {
+ TestParams baggingConfig = new TestParams.Builder()
+ .inputInstances(200_000)
+ .samplingSize(20_000)
+ .evaluationInstances(180_000)
+ .classifiedInstances(210_000)
+ .classificationsCorrect(60f)
+ .kappaStat(0f)
+ .kappaTempStat(0f)
+ .cliStringTemplate(TestParams.Templates.PREQEVAL_BAGGING_RANDOMTREE)
+ .prePollWait(10)
+ .resultFilePollTimeout(10)
+ .taskClassName(LocalDoTask.class.getName())
+ .build();
+ TestUtils.test(baggingConfig);
- @Test
- public void testBaggingLocal() throws Exception {
- TestParams baggingConfig = new TestParams.Builder()
- .inputInstances(200_000)
- .samplingSize(20_000)
- .evaluationInstances(180_000)
- .classifiedInstances(210_000)
- .classificationsCorrect(60f)
- .kappaStat(0f)
- .kappaTempStat(0f)
- .cliStringTemplate(TestParams.Templates.PREQEVAL_BAGGING_RANDOMTREE)
- .prePollWait(10)
- .resultFilePollTimeout(10)
- .taskClassName(LocalDoTask.class.getName())
- .build();
- TestUtils.test(baggingConfig);
+ }
- }
+ @Test
+ public void testNaiveBayesLocal() throws Exception {
- @Test
- public void testNaiveBayesLocal() throws Exception {
+ TestParams vhtConfig = new TestParams.Builder()
+ .inputInstances(200_000)
+ .samplingSize(20_000)
+ .evaluationInstances(200_000)
+ .classifiedInstances(200_000)
+ .classificationsCorrect(65f)
+ .kappaStat(0f)
+ .kappaTempStat(0f)
+ .cliStringTemplate(TestParams.Templates.PREQEVAL_NAIVEBAYES_HYPERPLANE)
+ .resultFilePollTimeout(10)
+ .prePollWait(10)
+ .taskClassName(LocalDoTask.class.getName())
+ .build();
+ TestUtils.test(vhtConfig);
- TestParams vhtConfig = new TestParams.Builder()
- .inputInstances(200_000)
- .samplingSize(20_000)
- .evaluationInstances(200_000)
- .classifiedInstances(200_000)
- .classificationsCorrect(65f)
- .kappaStat(0f)
- .kappaTempStat(0f)
- .cliStringTemplate(TestParams.Templates.PREQEVAL_NAIVEBAYES_HYPERPLANE)
- .resultFilePollTimeout(10)
- .prePollWait(10)
- .taskClassName(LocalDoTask.class.getName())
- .build();
- TestUtils.test(vhtConfig);
-
- }
+ }
}
diff --git a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactoryTest.java b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactoryTest.java
index 02a9295..bfd6fe1 100644
--- a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactoryTest.java
+++ b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleComponentFactoryTest.java
@@ -36,61 +36,64 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class SimpleComponentFactoryTest {
- @Tested private SimpleComponentFactory factory;
- @Mocked private Processor processor, processorReplica;
- @Mocked private EntranceProcessor entranceProcessor;
-
- private final int parallelism = 3;
- private final String topoName = "TestTopology";
-
+ @Tested
+ private SimpleComponentFactory factory;
+ @Mocked
+ private Processor processor, processorReplica;
+ @Mocked
+ private EntranceProcessor entranceProcessor;
- @Before
- public void setUp() throws Exception {
- factory = new SimpleComponentFactory();
- }
+ private final int parallelism = 3;
+ private final String topoName = "TestTopology";
- @Test
- public void testCreatePiNoParallelism() {
- ProcessingItem pi = factory.createPi(processor);
- assertNotNull("ProcessingItem created is null.",pi);
- assertEquals("ProcessingItem created is not a SimpleProcessingItem.",SimpleProcessingItem.class,pi.getClass());
- assertEquals("Parallelism of PI is not 1",1,pi.getParallelism(),0);
- }
-
- @Test
- public void testCreatePiWithParallelism() {
- ProcessingItem pi = factory.createPi(processor,parallelism);
- assertNotNull("ProcessingItem created is null.",pi);
- assertEquals("ProcessingItem created is not a SimpleProcessingItem.",SimpleProcessingItem.class,pi.getClass());
- assertEquals("Parallelism of PI is not ",parallelism,pi.getParallelism(),0);
- }
-
- @Test
- public void testCreateStream() {
- ProcessingItem pi = factory.createPi(processor);
-
- Stream stream = factory.createStream(pi);
- assertNotNull("Stream created is null",stream);
- assertEquals("Stream created is not a SimpleStream.",SimpleStream.class,stream.getClass());
- }
-
- @Test
- public void testCreateTopology() {
- Topology topology = factory.createTopology(topoName);
- assertNotNull("Topology created is null.",topology);
- assertEquals("Topology created is not a SimpleTopology.",SimpleTopology.class,topology.getClass());
- }
-
- @Test
- public void testCreateEntrancePi() {
- EntranceProcessingItem entrancePi = factory.createEntrancePi(entranceProcessor);
- assertNotNull("EntranceProcessingItem created is null.",entrancePi);
- assertEquals("EntranceProcessingItem created is not a SimpleEntranceProcessingItem.",SimpleEntranceProcessingItem.class,entrancePi.getClass());
- assertSame("EntranceProcessor is not set correctly.",entranceProcessor, entrancePi.getProcessor());
- }
+ @Before
+ public void setUp() throws Exception {
+ factory = new SimpleComponentFactory();
+ }
+
+ @Test
+ public void testCreatePiNoParallelism() {
+ ProcessingItem pi = factory.createPi(processor);
+ assertNotNull("ProcessingItem created is null.", pi);
+ assertEquals("ProcessingItem created is not a SimpleProcessingItem.", SimpleProcessingItem.class, pi.getClass());
+ assertEquals("Parallelism of PI is not 1", 1, pi.getParallelism(), 0);
+ }
+
+ @Test
+ public void testCreatePiWithParallelism() {
+ ProcessingItem pi = factory.createPi(processor, parallelism);
+ assertNotNull("ProcessingItem created is null.", pi);
+ assertEquals("ProcessingItem created is not a SimpleProcessingItem.", SimpleProcessingItem.class, pi.getClass());
+ assertEquals("Parallelism of PI is not ", parallelism, pi.getParallelism(), 0);
+ }
+
+ @Test
+ public void testCreateStream() {
+ ProcessingItem pi = factory.createPi(processor);
+
+ Stream stream = factory.createStream(pi);
+ assertNotNull("Stream created is null", stream);
+ assertEquals("Stream created is not a SimpleStream.", SimpleStream.class, stream.getClass());
+ }
+
+ @Test
+ public void testCreateTopology() {
+ Topology topology = factory.createTopology(topoName);
+ assertNotNull("Topology created is null.", topology);
+ assertEquals("Topology created is not a SimpleTopology.", SimpleTopology.class, topology.getClass());
+ }
+
+ @Test
+ public void testCreateEntrancePi() {
+ EntranceProcessingItem entrancePi = factory.createEntrancePi(entranceProcessor);
+ assertNotNull("EntranceProcessingItem created is null.", entrancePi);
+ assertEquals("EntranceProcessingItem created is not a SimpleEntranceProcessingItem.",
+ SimpleEntranceProcessingItem.class, entrancePi.getClass());
+ assertSame("EntranceProcessor is not set correctly.", entranceProcessor, entrancePi.getProcessor());
+ }
}
diff --git a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEngineTest.java b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEngineTest.java
index c4649ed..23b38b4 100644
--- a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEngineTest.java
+++ b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEngineTest.java
@@ -29,29 +29,32 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class SimpleEngineTest {
- @Tested private SimpleEngine unused;
- @Mocked private SimpleTopology topology;
- @Mocked private Runtime mockedRuntime;
-
- @Test
- public void testSubmitTopology() {
- new NonStrictExpectations() {
- {
- Runtime.getRuntime();
- result=mockedRuntime;
- mockedRuntime.exit(0);
- }
- };
- SimpleEngine.submitTopology(topology);
- new Verifications() {
- {
- topology.run();
- }
- };
- }
+ @Tested
+ private SimpleEngine unused;
+ @Mocked
+ private SimpleTopology topology;
+ @Mocked
+ private Runtime mockedRuntime;
+
+ @Test
+ public void testSubmitTopology() {
+ new NonStrictExpectations() {
+ {
+ Runtime.getRuntime();
+ result = mockedRuntime;
+ mockedRuntime.exit(0);
+ }
+ };
+ SimpleEngine.submitTopology(topology);
+ new Verifications() {
+ {
+ topology.run();
+ }
+ };
+ }
}
diff --git a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItemTest.java b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItemTest.java
index 41ae22b..0c1e475 100644
--- a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItemTest.java
+++ b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleEntranceProcessingItemTest.java
@@ -36,118 +36,137 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class SimpleEntranceProcessingItemTest {
- @Tested private SimpleEntranceProcessingItem entrancePi;
-
- @Mocked private EntranceProcessor entranceProcessor;
- @Mocked private Stream outputStream, anotherStream;
- @Mocked private ContentEvent event;
-
- @Mocked private Thread unused;
-
- /**
- * @throws java.lang.Exception
- */
- @Before
- public void setUp() throws Exception {
- entrancePi = new SimpleEntranceProcessingItem(entranceProcessor);
- }
+ @Tested
+ private SimpleEntranceProcessingItem entrancePi;
- @Test
- public void testContructor() {
- assertSame("EntranceProcessor is not set correctly.",entranceProcessor,entrancePi.getProcessor());
- }
-
- @Test
- public void testSetOutputStream() {
- entrancePi.setOutputStream(outputStream);
- assertSame("OutputStream is not set correctly.",outputStream,entrancePi.getOutputStream());
- }
-
- @Test
- public void testSetOutputStreamRepeate() {
- entrancePi.setOutputStream(outputStream);
- entrancePi.setOutputStream(outputStream);
- assertSame("OutputStream is not set correctly.",outputStream,entrancePi.getOutputStream());
- }
-
- @Test(expected=IllegalStateException.class)
- public void testSetOutputStreamError() {
- entrancePi.setOutputStream(outputStream);
- entrancePi.setOutputStream(anotherStream);
- }
-
- @Test
- public void testInjectNextEventSuccess() {
- entrancePi.setOutputStream(outputStream);
- new StrictExpectations() {
- {
- entranceProcessor.hasNext();
- result=true;
-
- entranceProcessor.nextEvent();
- result=event;
- }
- };
- entrancePi.injectNextEvent();
- new Verifications() {
- {
- outputStream.put(event);
- }
- };
- }
-
- @Test
- public void testStartSendingEvents() {
- entrancePi.setOutputStream(outputStream);
- new StrictExpectations() {
- {
- for (int i=0; i<1; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=false;
- }
-
- for (int i=0; i<5; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=true;
- entranceProcessor.nextEvent(); result=event;
- outputStream.put(event);
- }
-
- for (int i=0; i<2; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=false;
- }
-
- for (int i=0; i<5; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=true;
- entranceProcessor.nextEvent(); result=event;
- outputStream.put(event);
- }
+ @Mocked
+ private EntranceProcessor entranceProcessor;
+ @Mocked
+ private Stream outputStream, anotherStream;
+ @Mocked
+ private ContentEvent event;
- entranceProcessor.isFinished(); result=true; times=1;
- entranceProcessor.hasNext(); times=0;
- }
- };
- entrancePi.startSendingEvents();
- new Verifications() {
- {
- try {
- Thread.sleep(anyInt); times=3;
- } catch (InterruptedException e) {
+ @Mocked
+ private Thread unused;
- }
- }
- };
- }
-
- @Test(expected=IllegalStateException.class)
- public void testStartSendingEventsError() {
- entrancePi.startSendingEvents();
- }
+ /**
+ * @throws java.lang.Exception
+ */
+ @Before
+ public void setUp() throws Exception {
+ entrancePi = new SimpleEntranceProcessingItem(entranceProcessor);
+ }
+
+ @Test
+ public void testContructor() {
+ assertSame("EntranceProcessor is not set correctly.", entranceProcessor, entrancePi.getProcessor());
+ }
+
+ @Test
+ public void testSetOutputStream() {
+ entrancePi.setOutputStream(outputStream);
+ assertSame("OutputStream is not set correctly.", outputStream, entrancePi.getOutputStream());
+ }
+
+ @Test
+ public void testSetOutputStreamRepeate() {
+ entrancePi.setOutputStream(outputStream);
+ entrancePi.setOutputStream(outputStream);
+ assertSame("OutputStream is not set correctly.", outputStream, entrancePi.getOutputStream());
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testSetOutputStreamError() {
+ entrancePi.setOutputStream(outputStream);
+ entrancePi.setOutputStream(anotherStream);
+ }
+
+ @Test
+ public void testInjectNextEventSuccess() {
+ entrancePi.setOutputStream(outputStream);
+ new StrictExpectations() {
+ {
+ entranceProcessor.hasNext();
+ result = true;
+
+ entranceProcessor.nextEvent();
+ result = event;
+ }
+ };
+ entrancePi.injectNextEvent();
+ new Verifications() {
+ {
+ outputStream.put(event);
+ }
+ };
+ }
+
+ @Test
+ public void testStartSendingEvents() {
+ entrancePi.setOutputStream(outputStream);
+ new StrictExpectations() {
+ {
+ for (int i = 0; i < 1; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = false;
+ }
+
+ for (int i = 0; i < 5; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = true;
+ entranceProcessor.nextEvent();
+ result = event;
+ outputStream.put(event);
+ }
+
+ for (int i = 0; i < 2; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = false;
+ }
+
+ for (int i = 0; i < 5; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = true;
+ entranceProcessor.nextEvent();
+ result = event;
+ outputStream.put(event);
+ }
+
+ entranceProcessor.isFinished();
+ result = true;
+ times = 1;
+ entranceProcessor.hasNext();
+ times = 0;
+ }
+ };
+ entrancePi.startSendingEvents();
+ new Verifications() {
+ {
+ try {
+ Thread.sleep(anyInt);
+ times = 3;
+ } catch (InterruptedException e) {
+
+ }
+ }
+ };
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testStartSendingEventsError() {
+ entrancePi.startSendingEvents();
+ }
}
diff --git a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItemTest.java b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItemTest.java
index a4a288a..caa82bf 100644
--- a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItemTest.java
+++ b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleProcessingItemTest.java
@@ -40,81 +40,85 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class SimpleProcessingItemTest {
- @Tested private SimpleProcessingItem pi;
-
- @Mocked private Processor processor;
- @Mocked private SimpleStream stream;
- @Mocked private StreamDestination destination;
- @Mocked private ContentEvent event;
-
- private final int parallelism = 4;
- private final int counter = 2;
-
-
- @Before
- public void setUp() throws Exception {
- pi = new SimpleProcessingItem(processor, parallelism);
- }
+ @Tested
+ private SimpleProcessingItem pi;
- @Test
- public void testConstructor() {
- assertSame("Processor was not set correctly.",processor,pi.getProcessor());
- assertEquals("Parallelism was not set correctly.",parallelism,pi.getParallelism(),0);
- }
-
- @Test
- public void testConnectInputShuffleStream() {
- new Expectations() {
- {
- destination = new StreamDestination(pi, parallelism, PartitioningScheme.SHUFFLE);
- stream.addDestination(destination);
- }
- };
- pi.connectInputShuffleStream(stream);
- }
-
- @Test
- public void testConnectInputKeyStream() {
- new Expectations() {
- {
- destination = new StreamDestination(pi, parallelism, PartitioningScheme.GROUP_BY_KEY);
- stream.addDestination(destination);
- }
- };
- pi.connectInputKeyStream(stream);
- }
-
- @Test
- public void testConnectInputAllStream() {
- new Expectations() {
- {
- destination = new StreamDestination(pi, parallelism, PartitioningScheme.BROADCAST);
- stream.addDestination(destination);
- }
- };
- pi.connectInputAllStream(stream);
- }
-
- @Test
- public void testProcessEvent() {
- new Expectations() {
- {
- for (int i=0; i<parallelism; i++) {
- processor.newProcessor(processor);
- result=processor;
-
- processor.onCreate(anyInt);
- }
-
- processor.process(event);
- }
- };
- pi.processEvent(event, counter);
-
- }
+ @Mocked
+ private Processor processor;
+ @Mocked
+ private SimpleStream stream;
+ @Mocked
+ private StreamDestination destination;
+ @Mocked
+ private ContentEvent event;
+
+ private final int parallelism = 4;
+ private final int counter = 2;
+
+ @Before
+ public void setUp() throws Exception {
+ pi = new SimpleProcessingItem(processor, parallelism);
+ }
+
+ @Test
+ public void testConstructor() {
+ assertSame("Processor was not set correctly.", processor, pi.getProcessor());
+ assertEquals("Parallelism was not set correctly.", parallelism, pi.getParallelism(), 0);
+ }
+
+ @Test
+ public void testConnectInputShuffleStream() {
+ new Expectations() {
+ {
+ destination = new StreamDestination(pi, parallelism, PartitioningScheme.SHUFFLE);
+ stream.addDestination(destination);
+ }
+ };
+ pi.connectInputShuffleStream(stream);
+ }
+
+ @Test
+ public void testConnectInputKeyStream() {
+ new Expectations() {
+ {
+ destination = new StreamDestination(pi, parallelism, PartitioningScheme.GROUP_BY_KEY);
+ stream.addDestination(destination);
+ }
+ };
+ pi.connectInputKeyStream(stream);
+ }
+
+ @Test
+ public void testConnectInputAllStream() {
+ new Expectations() {
+ {
+ destination = new StreamDestination(pi, parallelism, PartitioningScheme.BROADCAST);
+ stream.addDestination(destination);
+ }
+ };
+ pi.connectInputAllStream(stream);
+ }
+
+ @Test
+ public void testProcessEvent() {
+ new Expectations() {
+ {
+ for (int i = 0; i < parallelism; i++) {
+ processor.newProcessor(processor);
+ result = processor;
+
+ processor.onCreate(anyInt);
+ }
+
+ processor.process(event);
+ }
+ };
+ pi.processEvent(event, counter);
+
+ }
}
diff --git a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleStreamTest.java b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleStreamTest.java
index 2a625b5..c8f6c5d 100644
--- a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleStreamTest.java
+++ b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleStreamTest.java
@@ -40,72 +40,82 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
@RunWith(Parameterized.class)
public class SimpleStreamTest {
- @Tested private SimpleStream stream;
-
- @Mocked private SimpleProcessingItem sourcePi, destPi;
- @Mocked private ContentEvent event;
- @Mocked private StreamDestination destination;
+ @Tested
+ private SimpleStream stream;
- private final String eventKey = "eventkey";
- private final int parallelism;
- private final PartitioningScheme scheme;
-
-
- @Parameters
- public static Collection<Object[]> generateParameters() {
- return Arrays.asList(new Object[][] {
- { 2, PartitioningScheme.SHUFFLE },
- { 3, PartitioningScheme.GROUP_BY_KEY },
- { 4, PartitioningScheme.BROADCAST }
- });
- }
-
- public SimpleStreamTest(int parallelism, PartitioningScheme scheme) {
- this.parallelism = parallelism;
- this.scheme = scheme;
- }
-
- @Before
- public void setUp() throws Exception {
- stream = new SimpleStream(sourcePi);
- stream.addDestination(destination);
- }
+ @Mocked
+ private SimpleProcessingItem sourcePi, destPi;
+ @Mocked
+ private ContentEvent event;
+ @Mocked
+ private StreamDestination destination;
- @Test
- public void testPut() {
- new NonStrictExpectations() {
- {
- event.getKey(); result=eventKey;
- destination.getProcessingItem(); result=destPi;
- destination.getPartitioningScheme(); result=scheme;
- destination.getParallelism(); result=parallelism;
-
- }
- };
- switch(this.scheme) {
- case SHUFFLE: case GROUP_BY_KEY:
- new Expectations() {
- {
- // TODO: restrict the range of counter value
- destPi.processEvent(event, anyInt); times=1;
- }
- };
- break;
- case BROADCAST:
- new Expectations() {
- {
- // TODO: restrict the range of counter value
- destPi.processEvent(event, anyInt); times=parallelism;
- }
- };
- break;
- }
- stream.put(event);
- }
+ private final String eventKey = "eventkey";
+ private final int parallelism;
+ private final PartitioningScheme scheme;
+
+ @Parameters
+ public static Collection<Object[]> generateParameters() {
+ return Arrays.asList(new Object[][] {
+ { 2, PartitioningScheme.SHUFFLE },
+ { 3, PartitioningScheme.GROUP_BY_KEY },
+ { 4, PartitioningScheme.BROADCAST }
+ });
+ }
+
+ public SimpleStreamTest(int parallelism, PartitioningScheme scheme) {
+ this.parallelism = parallelism;
+ this.scheme = scheme;
+ }
+
+ @Before
+ public void setUp() throws Exception {
+ stream = new SimpleStream(sourcePi);
+ stream.addDestination(destination);
+ }
+
+ @Test
+ public void testPut() {
+ new NonStrictExpectations() {
+ {
+ event.getKey();
+ result = eventKey;
+ destination.getProcessingItem();
+ result = destPi;
+ destination.getPartitioningScheme();
+ result = scheme;
+ destination.getParallelism();
+ result = parallelism;
+
+ }
+ };
+ switch (this.scheme) {
+ case SHUFFLE:
+ case GROUP_BY_KEY:
+ new Expectations() {
+ {
+ // TODO: restrict the range of counter value
+ destPi.processEvent(event, anyInt);
+ times = 1;
+ }
+ };
+ break;
+ case BROADCAST:
+ new Expectations() {
+ {
+ // TODO: restrict the range of counter value
+ destPi.processEvent(event, anyInt);
+ times = parallelism;
+ }
+ };
+ break;
+ }
+ stream.put(event);
+ }
}
diff --git a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleTopologyTest.java b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleTopologyTest.java
index 2423778..418ad14 100644
--- a/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleTopologyTest.java
+++ b/samoa-local/src/test/java/com/yahoo/labs/samoa/topology/impl/SimpleTopologyTest.java
@@ -31,63 +31,64 @@
* #L%
*/
-
-
import com.yahoo.labs.samoa.core.EntranceProcessor;
import com.yahoo.labs.samoa.topology.EntranceProcessingItem;
/**
* @author Anh Thu Vu
- *
+ *
*/
public class SimpleTopologyTest {
- @Tested private SimpleTopology topology;
-
- @Mocked private SimpleEntranceProcessingItem entrancePi;
- @Mocked private EntranceProcessor entranceProcessor;
-
- @Before
- public void setUp() throws Exception {
- topology = new SimpleTopology("TestTopology");
- }
+ @Tested
+ private SimpleTopology topology;
- @Test
- public void testAddEntrancePi() {
- topology.addEntranceProcessingItem(entrancePi);
-
- Set<EntranceProcessingItem> entrancePIs = topology.getEntranceProcessingItems();
- assertNotNull("Set of entrance PIs is null.",entrancePIs);
- assertEquals("Number of entrance PI in SimpleTopology must be 1",1,entrancePIs.size());
- assertSame("Entrance PI was not set correctly.",entrancePi,entrancePIs.toArray()[0]);
- // TODO: verify that entrance PI is in the set of ProcessingItems
- // Need to access topology's set of PIs (getProcessingItems() method)
- }
-
- @Test
- public void testRun() {
- topology.addEntranceProcessingItem(entrancePi);
-
- new NonStrictExpectations() {
- {
- entrancePi.getProcessor();
- result=entranceProcessor;
-
- }
- };
-
- new Expectations() {
- {
- entranceProcessor.onCreate(anyInt);
- entrancePi.startSendingEvents();
- }
- };
- topology.run();
- }
-
- @Test(expected=IllegalStateException.class)
- public void testRunWithoutEntrancePI() {
- topology.run();
- }
+ @Mocked
+ private SimpleEntranceProcessingItem entrancePi;
+ @Mocked
+ private EntranceProcessor entranceProcessor;
+
+ @Before
+ public void setUp() throws Exception {
+ topology = new SimpleTopology("TestTopology");
+ }
+
+ @Test
+ public void testAddEntrancePi() {
+ topology.addEntranceProcessingItem(entrancePi);
+
+ Set<EntranceProcessingItem> entrancePIs = topology.getEntranceProcessingItems();
+ assertNotNull("Set of entrance PIs is null.", entrancePIs);
+ assertEquals("Number of entrance PI in SimpleTopology must be 1", 1, entrancePIs.size());
+ assertSame("Entrance PI was not set correctly.", entrancePi, entrancePIs.toArray()[0]);
+ // TODO: verify that entrance PI is in the set of ProcessingItems
+ // Need to access topology's set of PIs (getProcessingItems() method)
+ }
+
+ @Test
+ public void testRun() {
+ topology.addEntranceProcessingItem(entrancePi);
+
+ new NonStrictExpectations() {
+ {
+ entrancePi.getProcessor();
+ result = entranceProcessor;
+
+ }
+ };
+
+ new Expectations() {
+ {
+ entranceProcessor.onCreate(anyInt);
+ entrancePi.startSendingEvents();
+ }
+ };
+ topology.run();
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testRunWithoutEntrancePI() {
+ topology.run();
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ComponentFactory.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ComponentFactory.java
index 33299ac..b627416 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ComponentFactory.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ComponentFactory.java
@@ -40,58 +40,59 @@
*/
public class S4ComponentFactory implements ComponentFactory {
- public static final Logger logger = LoggerFactory.getLogger(S4ComponentFactory.class);
- protected S4DoTask app;
+ public static final Logger logger = LoggerFactory.getLogger(S4ComponentFactory.class);
+ protected S4DoTask app;
- @Override
- public ProcessingItem createPi(Processor processor, int paralellism) {
- S4ProcessingItem processingItem = new S4ProcessingItem(app);
- // TODO refactor how to set the paralellism level
- processingItem.setParalellismLevel(paralellism);
- processingItem.setProcessor(processor);
+ @Override
+ public ProcessingItem createPi(Processor processor, int paralellism) {
+ S4ProcessingItem processingItem = new S4ProcessingItem(app);
+ // TODO refactor how to set the paralellism level
+ processingItem.setParalellismLevel(paralellism);
+ processingItem.setProcessor(processor);
- return processingItem;
- }
+ return processingItem;
+ }
- @Override
- public ProcessingItem createPi(Processor processor) {
- return this.createPi(processor, 1);
- }
+ @Override
+ public ProcessingItem createPi(Processor processor) {
+ return this.createPi(processor, 1);
+ }
- @Override
- public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor) {
- // TODO Create source Entry processing item that connects to an external stream
- S4EntranceProcessingItem entrancePi = new S4EntranceProcessingItem(entranceProcessor, app);
- entrancePi.setParallelism(1); // FIXME should not be set to 1 statically
- return entrancePi;
- }
+ @Override
+ public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor) {
+ // TODO Create source Entry processing item that connects to an external
+ // stream
+ S4EntranceProcessingItem entrancePi = new S4EntranceProcessingItem(entranceProcessor, app);
+ entrancePi.setParallelism(1); // FIXME should not be set to 1 statically
+ return entrancePi;
+ }
- @Override
- public Stream createStream(IProcessingItem sourcePi) {
- S4Stream aStream = new S4Stream(app);
- return aStream;
- }
+ @Override
+ public Stream createStream(IProcessingItem sourcePi) {
+ S4Stream aStream = new S4Stream(app);
+ return aStream;
+ }
- @Override
- public Topology createTopology(String topoName) {
- return new S4Topology(topoName);
- }
+ @Override
+ public Topology createTopology(String topoName) {
+ return new S4Topology(topoName);
+ }
- /**
- * Initialization method.
- *
- * @param evalTask
- */
- public void init(String evalTask) {
- // Task is initiated in the DoTaskApp
- }
+ /**
+ * Initialization method.
+ *
+ * @param evalTask
+ */
+ public void init(String evalTask) {
+ // Task is initiated in the DoTaskApp
+ }
- /**
- * Sets S4 application.
- *
- * @param app
- */
- public void setApp(S4DoTask app) {
- this.app = app;
- }
+ /**
+ * Sets S4 application.
+ *
+ * @param app
+ */
+ public void setApp(S4DoTask app) {
+ this.app = app;
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4DoTask.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4DoTask.java
index 0f474a4..3691a82 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4DoTask.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4DoTask.java
@@ -56,208 +56,213 @@
*/
final public class S4DoTask extends App {
- private final Logger logger = LoggerFactory.getLogger(S4DoTask.class);
- Task task;
+ private final Logger logger = LoggerFactory.getLogger(S4DoTask.class);
+ Task task;
- @Inject @Named("evalTask") public String evalTask;
+ @Inject
+ @Named("evalTask")
+ public String evalTask;
- public S4DoTask() {
- super();
- }
+ public S4DoTask() {
+ super();
+ }
- /** The engine. */
- protected ComponentFactory componentFactory;
+ /** The engine. */
+ protected ComponentFactory componentFactory;
- /**
- * Gets the factory.
- *
- * @return the factory
- */
- public ComponentFactory getFactory() {
- return componentFactory;
- }
+ /**
+ * Gets the factory.
+ *
+ * @return the factory
+ */
+ public ComponentFactory getFactory() {
+ return componentFactory;
+ }
- /**
- * Sets the factory.
- *
- * @param factory
- * the new factory
- */
- public void setFactory(ComponentFactory factory) {
- this.componentFactory = factory;
- }
+ /**
+ * Sets the factory.
+ *
+ * @param factory
+ * the new factory
+ */
+ public void setFactory(ComponentFactory factory) {
+ this.componentFactory = factory;
+ }
- /*
- * Build the application
- *
- * @see org.apache.s4.core.App#onInit()
- */
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.App#onInit()
- */
- @Override
- protected void onInit() {
- logger.info("DoTaskApp onInit");
- // ConsoleReporters prints S4 metrics
- // MetricsRegistry mr = new MetricsRegistry();
- //
- // CsvReporter.enable(new File(System.getProperty("user.home")
- // + "/monitor/"), 10, TimeUnit.SECONDS);
- // ConsoleReporter.enable(10, TimeUnit.SECONDS);
- try {
- System.err.println();
- System.err.println(Globals.getWorkbenchInfoString());
- System.err.println();
-
- } catch (Exception ex) {
- ex.printStackTrace();
- }
- S4ComponentFactory factory = new S4ComponentFactory();
- factory.setApp(this);
-
- // logger.debug("LC {}", lc);
-
- // task = TaskProvider.getTask(evalTask);
-
- // EXAMPLE OPTIONS
- // -l Clustream -g Clustream -i 100000 -s (RandomRBFGeneratorEvents -K
- // 5 -N 0.0)
- // String[] args = new String[] {evalTask,"-l", "Clustream","-g",
- // "Clustream", "-i", "100000", "-s", "(RamdomRBFGeneratorsEvents",
- // "-K", "5", "-N", "0.0)"};
- // String[] args = new String[] { evalTask, "-l", "clustream.Clustream",
- // "-g", "clustream.Clustream", "-i", "100000", "-s",
- // "(RandomRBFGeneratorEvents", "-K", "5", "-N", "0.0)" };
- logger.debug("PARAMETERS {}", evalTask);
- // params = params.replace(":", " ");
- List<String> parameters = new ArrayList<String>();
- // parameters.add(evalTask);
- try {
- parameters.addAll(Arrays.asList(URLDecoder.decode(evalTask, "UTF-8").split(" ")));
- } catch (UnsupportedEncodingException ex) {
- ex.printStackTrace();
- }
- String[] args = parameters.toArray(new String[0]);
- Option[] extraOptions = new Option[] {};
- // build a single string by concatenating cli options
- StringBuilder cliString = new StringBuilder();
- for (int i = 0; i < args.length; i++) {
- cliString.append(" ").append(args[i]);
- }
-
- // parse options
- try {
- task = (Task) ClassOption.cliStringToObject(cliString.toString(), Task.class, extraOptions);
- task.setFactory(factory);
- task.init();
- } catch (Exception e) {
- e.printStackTrace();
- }
-
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.App#onStart()
- */
- @Override
- protected void onStart() {
- logger.info("Starting DoTaskApp... App Partition [{}]", this.getPartitionId());
- // <<<<<<< HEAD Task doesn't have start in latest storm-impl
- // TODO change the way the app starts
- // if (this.getPartitionId() == 0)
- S4Topology s4topology = (S4Topology) getTask().getTopology();
- S4EntranceProcessingItem epi = (S4EntranceProcessingItem) s4topology.getEntranceProcessingItem();
- while (epi.injectNextEvent())
- // inject events from the EntrancePI
- ;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.App#onClose()
- */
- @Override
- protected void onClose() {
- System.out.println("Closing DoTaskApp...");
-
- }
-
- /**
- * Gets the task.
- *
- * @return the task
- */
- public Task getTask() {
- return task;
- }
-
- // These methods are protected in App and can not be accessed from outside.
- // They are
- // called from parallel classifiers and evaluations. Is there a better way
- // to do that?
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.App#createPE(java.lang.Class)
- */
- @Override
- public <T extends ProcessingElement> T createPE(Class<T> type) {
- return super.createPE(type);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.App#createStream(java.lang.String, org.apache.s4.base.KeyFinder, org.apache.s4.core.ProcessingElement[])
- */
- @Override
- public <T extends Event> Stream<T> createStream(String name, KeyFinder<T> finder, ProcessingElement... processingElements) {
- return super.createStream(name, finder, processingElements);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.s4.core.App#createStream(java.lang.String, org.apache.s4.core.ProcessingElement[])
- */
- @Override
- public <T extends Event> Stream<T> createStream(String name, ProcessingElement... processingElements) {
- return super.createStream(name, processingElements);
- }
-
- // @com.beust.jcommander.Parameters(separators = "=")
- // class Parameters {
+ /*
+ * Build the application
+ *
+ * @see org.apache.s4.core.App#onInit()
+ */
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.s4.core.App#onInit()
+ */
+ @Override
+ protected void onInit() {
+ logger.info("DoTaskApp onInit");
+ // ConsoleReporters prints S4 metrics
+ // MetricsRegistry mr = new MetricsRegistry();
//
- // @Parameter(names={"-lc","-local"}, description="Local clustering method")
- // private String localClustering;
- //
- // @Parameter(names={"-gc","-global"},
- // description="Global clustering method")
- // private String globalClustering;
- //
- // }
- //
- // class ParametersConverter {// implements IStringConverter<String[]> {
- //
- //
- // public String[] convertToArgs(String value) {
- //
- // String[] params = value.split(",");
- // String[] args = new String[params.length*2];
- // for(int i=0; i<params.length ; i++) {
- // args[i] = params[i].split("=")[0];
- // args[i+1] = params[i].split("=")[1];
- // i++;
- // }
- // return args;
- // }
- //
- // }
+ // CsvReporter.enable(new File(System.getProperty("user.home")
+ // + "/monitor/"), 10, TimeUnit.SECONDS);
+ // ConsoleReporter.enable(10, TimeUnit.SECONDS);
+ try {
+ System.err.println();
+ System.err.println(Globals.getWorkbenchInfoString());
+ System.err.println();
+
+ } catch (Exception ex) {
+ ex.printStackTrace();
+ }
+ S4ComponentFactory factory = new S4ComponentFactory();
+ factory.setApp(this);
+
+ // logger.debug("LC {}", lc);
+
+ // task = TaskProvider.getTask(evalTask);
+
+ // EXAMPLE OPTIONS
+ // -l Clustream -g Clustream -i 100000 -s (RandomRBFGeneratorEvents -K
+ // 5 -N 0.0)
+ // String[] args = new String[] {evalTask,"-l", "Clustream","-g",
+ // "Clustream", "-i", "100000", "-s", "(RamdomRBFGeneratorsEvents",
+ // "-K", "5", "-N", "0.0)"};
+ // String[] args = new String[] { evalTask, "-l", "clustream.Clustream",
+ // "-g", "clustream.Clustream", "-i", "100000", "-s",
+ // "(RandomRBFGeneratorEvents", "-K", "5", "-N", "0.0)" };
+ logger.debug("PARAMETERS {}", evalTask);
+ // params = params.replace(":", " ");
+ List<String> parameters = new ArrayList<String>();
+ // parameters.add(evalTask);
+ try {
+ parameters.addAll(Arrays.asList(URLDecoder.decode(evalTask, "UTF-8").split(" ")));
+ } catch (UnsupportedEncodingException ex) {
+ ex.printStackTrace();
+ }
+ String[] args = parameters.toArray(new String[0]);
+ Option[] extraOptions = new Option[] {};
+ // build a single string by concatenating cli options
+ StringBuilder cliString = new StringBuilder();
+ for (int i = 0; i < args.length; i++) {
+ cliString.append(" ").append(args[i]);
+ }
+
+ // parse options
+ try {
+ task = (Task) ClassOption.cliStringToObject(cliString.toString(), Task.class, extraOptions);
+ task.setFactory(factory);
+ task.init();
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.s4.core.App#onStart()
+ */
+ @Override
+ protected void onStart() {
+ logger.info("Starting DoTaskApp... App Partition [{}]", this.getPartitionId());
+ // <<<<<<< HEAD Task doesn't have start in latest storm-impl
+ // TODO change the way the app starts
+ // if (this.getPartitionId() == 0)
+ S4Topology s4topology = (S4Topology) getTask().getTopology();
+ S4EntranceProcessingItem epi = (S4EntranceProcessingItem) s4topology.getEntranceProcessingItem();
+ while (epi.injectNextEvent())
+ // inject events from the EntrancePI
+ ;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.s4.core.App#onClose()
+ */
+ @Override
+ protected void onClose() {
+ System.out.println("Closing DoTaskApp...");
+
+ }
+
+ /**
+ * Gets the task.
+ *
+ * @return the task
+ */
+ public Task getTask() {
+ return task;
+ }
+
+ // These methods are protected in App and can not be accessed from outside.
+ // They are
+ // called from parallel classifiers and evaluations. Is there a better way
+ // to do that?
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.s4.core.App#createPE(java.lang.Class)
+ */
+ @Override
+ public <T extends ProcessingElement> T createPE(Class<T> type) {
+ return super.createPE(type);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.s4.core.App#createStream(java.lang.String,
+ * org.apache.s4.base.KeyFinder, org.apache.s4.core.ProcessingElement[])
+ */
+ @Override
+ public <T extends Event> Stream<T> createStream(String name, KeyFinder<T> finder,
+ ProcessingElement... processingElements) {
+ return super.createStream(name, finder, processingElements);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.s4.core.App#createStream(java.lang.String,
+ * org.apache.s4.core.ProcessingElement[])
+ */
+ @Override
+ public <T extends Event> Stream<T> createStream(String name, ProcessingElement... processingElements) {
+ return super.createStream(name, processingElements);
+ }
+
+ // @com.beust.jcommander.Parameters(separators = "=")
+ // class Parameters {
+ //
+ // @Parameter(names={"-lc","-local"}, description="Local clustering method")
+ // private String localClustering;
+ //
+ // @Parameter(names={"-gc","-global"},
+ // description="Global clustering method")
+ // private String globalClustering;
+ //
+ // }
+ //
+ // class ParametersConverter {// implements IStringConverter<String[]> {
+ //
+ //
+ // public String[] convertToArgs(String value) {
+ //
+ // String[] params = value.split(",");
+ // String[] args = new String[params.length*2];
+ // for(int i=0; i<params.length ; i++) {
+ // args[i] = params[i].split("=")[0];
+ // args[i+1] = params[i].split("=")[1];
+ // i++;
+ // }
+ // return args;
+ // }
+ //
+ // }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4EntranceProcessingItem.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4EntranceProcessingItem.java
index 2b0c595..6f374fa 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4EntranceProcessingItem.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4EntranceProcessingItem.java
@@ -32,89 +32,89 @@
public class S4EntranceProcessingItem extends ProcessingElement implements EntranceProcessingItem {
- private EntranceProcessor entranceProcessor;
- // private S4DoTask app;
- private int parallelism;
- protected Stream outputStream;
+ private EntranceProcessor entranceProcessor;
+ // private S4DoTask app;
+ private int parallelism;
+ protected Stream outputStream;
- /**
- * Constructor of an S4 entrance processing item.
- *
- * @param app
- * : S4 application
- */
- public S4EntranceProcessingItem(EntranceProcessor entranceProcessor, App app) {
- super(app);
- this.entranceProcessor = entranceProcessor;
- // this.app = (S4DoTask) app;
- // this.setSingleton(true);
+ /**
+ * Constructor of an S4 entrance processing item.
+ *
+ * @param app
+ * : S4 application
+ */
+ public S4EntranceProcessingItem(EntranceProcessor entranceProcessor, App app) {
+ super(app);
+ this.entranceProcessor = entranceProcessor;
+ // this.app = (S4DoTask) app;
+ // this.setSingleton(true);
+ }
+
+ public void setParallelism(int parallelism) {
+ this.parallelism = parallelism;
+ }
+
+ public int getParallelism() {
+ return this.parallelism;
+ }
+
+ @Override
+ public EntranceProcessor getProcessor() {
+ return this.entranceProcessor;
+ }
+
+ //
+ // @Override
+ // public void put(Instance inst) {
+ // // do nothing
+ // // may not needed
+ // }
+
+ @Override
+ protected void onCreate() {
+ // was commented
+ if (this.entranceProcessor != null) {
+ // TODO revisit if we need to change it to a clone() call
+ this.entranceProcessor = (EntranceProcessor) this.entranceProcessor.newProcessor(this.entranceProcessor);
+ this.entranceProcessor.onCreate(Integer.parseInt(getId()));
}
+ }
- public void setParallelism(int parallelism) {
- this.parallelism = parallelism;
- }
+ @Override
+ protected void onRemove() {
+ // do nothing
+ }
- public int getParallelism() {
- return this.parallelism;
- }
+ //
+ // /**
+ // * Sets the entrance processing item processor.
+ // *
+ // * @param processor
+ // */
+ // public void setProcessor(Processor processor) {
+ // this.entranceProcessor = processor;
+ // }
- @Override
- public EntranceProcessor getProcessor() {
- return this.entranceProcessor;
- }
+ @Override
+ public void setName(String name) {
+ super.setName(name);
+ }
- //
- // @Override
- // public void put(Instance inst) {
- // // do nothing
- // // may not needed
- // }
+ @Override
+ public EntranceProcessingItem setOutputStream(Stream stream) {
+ if (this.outputStream != null)
+ throw new IllegalStateException("Output stream for an EntrancePI sohuld be initialized only once");
+ this.outputStream = stream;
+ return this;
+ }
- @Override
- protected void onCreate() {
- // was commented
- if (this.entranceProcessor != null) {
- // TODO revisit if we need to change it to a clone() call
- this.entranceProcessor = (EntranceProcessor) this.entranceProcessor.newProcessor(this.entranceProcessor);
- this.entranceProcessor.onCreate(Integer.parseInt(getId()));
- }
- }
-
- @Override
- protected void onRemove() {
- // do nothing
- }
-
- //
- // /**
- // * Sets the entrance processing item processor.
- // *
- // * @param processor
- // */
- // public void setProcessor(Processor processor) {
- // this.entranceProcessor = processor;
- // }
-
- @Override
- public void setName(String name) {
- super.setName(name);
- }
-
- @Override
- public EntranceProcessingItem setOutputStream(Stream stream) {
- if (this.outputStream != null)
- throw new IllegalStateException("Output stream for an EntrancePI sohuld be initialized only once");
- this.outputStream = stream;
- return this;
- }
-
- public boolean injectNextEvent() {
- if (entranceProcessor.hasNext()) {
- ContentEvent nextEvent = this.entranceProcessor.nextEvent();
- outputStream.put(nextEvent);
- return entranceProcessor.hasNext();
- } else
- return false;
- // return !nextEvent.isLastEvent();
- }
+ public boolean injectNextEvent() {
+ if (entranceProcessor.hasNext()) {
+ ContentEvent nextEvent = this.entranceProcessor.nextEvent();
+ outputStream.put(nextEvent);
+ return entranceProcessor.hasNext();
+ } else
+ return false;
+ // return !nextEvent.isLastEvent();
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Event.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Event.java
index 8f8ad9f..62c623c 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Event.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Event.java
@@ -36,55 +36,57 @@
@Immutable
final public class S4Event extends Event {
- private String key;
-
- public String getKey() {
- return key;
- }
+ private String key;
- public void setKey(String key) {
- this.key = key;
- }
+ public String getKey() {
+ return key;
+ }
- /** The content event. */
- private ContentEvent contentEvent;
-
- /**
- * Instantiates a new instance event.
- */
- public S4Event() {
- // Needed for serialization of kryo
- }
+ public void setKey(String key) {
+ this.key = key;
+ }
- /**
- * Instantiates a new instance event.
- *
- * @param contentEvent the content event
- */
- public S4Event(ContentEvent contentEvent) {
- if (contentEvent != null) {
- this.contentEvent = contentEvent;
- this.key = contentEvent.getKey();
-
- }
- }
+ /** The content event. */
+ private ContentEvent contentEvent;
- /**
- * Gets the content event.
- *
- * @return the content event
- */
- public ContentEvent getContentEvent() {
- return contentEvent;
- }
+ /**
+ * Instantiates a new instance event.
+ */
+ public S4Event() {
+ // Needed for serialization of kryo
+ }
- /**
- * Sets the content event.
- *
- * @param contentEvent the new content event
- */
- public void setContentEvent(ContentEvent contentEvent) {
- this.contentEvent = contentEvent;
- }
+ /**
+ * Instantiates a new instance event.
+ *
+ * @param contentEvent
+ * the content event
+ */
+ public S4Event(ContentEvent contentEvent) {
+ if (contentEvent != null) {
+ this.contentEvent = contentEvent;
+ this.key = contentEvent.getKey();
+
+ }
+ }
+
+ /**
+ * Gets the content event.
+ *
+ * @return the content event
+ */
+ public ContentEvent getContentEvent() {
+ return contentEvent;
+ }
+
+ /**
+ * Sets the content event.
+ *
+ * @param contentEvent
+ * the new content event
+ */
+ public void setContentEvent(ContentEvent contentEvent) {
+ this.contentEvent = contentEvent;
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ProcessingItem.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ProcessingItem.java
index 1351159..da5644d 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ProcessingItem.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4ProcessingItem.java
@@ -35,154 +35,154 @@
import com.yahoo.labs.samoa.topology.Stream;
/**
- * S4 Platform platform specific processing item, inherits from S4 ProcessinElemnt.
+ * S4 Platform platform specific processing item, inherits from S4
+ * ProcessinElemnt.
*
* @author severien
- *
+ *
*/
public class S4ProcessingItem extends ProcessingElement implements
- ProcessingItem {
+ ProcessingItem {
- public static final Logger logger = LoggerFactory
- .getLogger(S4ProcessingItem.class);
+ public static final Logger logger = LoggerFactory
+ .getLogger(S4ProcessingItem.class);
- private Processor processor;
- private int paralellismLevel;
- private S4DoTask app;
+ private Processor processor;
+ private int paralellismLevel;
+ private S4DoTask app;
- private static final String NAME="PROCESSING-ITEM-";
- private static int OBJ_COUNTER=0;
-
- /**
- * Constructor of S4 ProcessingItem.
- *
- * @param app : S4 application
- */
- public S4ProcessingItem(App app) {
- super(app);
- super.setName(NAME+OBJ_COUNTER);
- OBJ_COUNTER++;
- this.app = (S4DoTask) app;
- this.paralellismLevel = 1;
- }
+ private static final String NAME = "PROCESSING-ITEM-";
+ private static int OBJ_COUNTER = 0;
- @Override
- public String getName() {
- return super.getName();
- }
-
- /**
- * Gets processing item paralellism level.
- *
- * @return int
- */
- public int getParalellismLevel() {
- return paralellismLevel;
- }
+ /**
+ * Constructor of S4 ProcessingItem.
+ *
+ * @param app
+ * : S4 application
+ */
+ public S4ProcessingItem(App app) {
+ super(app);
+ super.setName(NAME + OBJ_COUNTER);
+ OBJ_COUNTER++;
+ this.app = (S4DoTask) app;
+ this.paralellismLevel = 1;
+ }
- /**
- * Sets processing item paralellism level.
- *
- * @param paralellismLevel
- */
- public void setParalellismLevel(int paralellismLevel) {
- this.paralellismLevel = paralellismLevel;
- }
+ @Override
+ public String getName() {
+ return super.getName();
+ }
- /**
- * onEvent method.
- *
- * @param event
- */
- public void onEvent(S4Event event) {
- if (processor.process(event.getContentEvent()) == true) {
- close();
- }
- }
+ /**
+ * Gets processing item paralellism level.
+ *
+ * @return int
+ */
+ public int getParalellismLevel() {
+ return paralellismLevel;
+ }
- /**
- * Sets S4 processing item processor.
- *
- * @param processor
- */
- public void setProcessor(Processor processor) {
- this.processor = processor;
- }
+ /**
+ * Sets processing item paralellism level.
+ *
+ * @param paralellismLevel
+ */
+ public void setParalellismLevel(int paralellismLevel) {
+ this.paralellismLevel = paralellismLevel;
+ }
- // Methods from ProcessingItem
- @Override
- public Processor getProcessor() {
- return processor;
- }
+ /**
+ * onEvent method.
+ *
+ * @param event
+ */
+ public void onEvent(S4Event event) {
+ if (processor.process(event.getContentEvent()) == true) {
+ close();
+ }
+ }
- /**
- * KeyFinder sets the keys for a specific event.
- *
- * @return KeyFinder
- */
- private KeyFinder<S4Event> getKeyFinder() {
- KeyFinder<S4Event> keyFinder = new KeyFinder<S4Event>() {
- @Override
- public List<String> get(S4Event s4event) {
- List<String> results = new ArrayList<String>();
- results.add(s4event.getKey());
- return results;
- }
- };
+ /**
+ * Sets S4 processing item processor.
+ *
+ * @param processor
+ */
+ public void setProcessor(Processor processor) {
+ this.processor = processor;
+ }
- return keyFinder;
- }
-
-
- @Override
- public ProcessingItem connectInputAllStream(Stream inputStream) {
+ // Methods from ProcessingItem
+ @Override
+ public Processor getProcessor() {
+ return processor;
+ }
- S4Stream stream = (S4Stream) inputStream;
- stream.setParallelism(this.paralellismLevel);
- stream.addStream(inputStream.getStreamId(),
- getKeyFinder(), this, S4Stream.BROADCAST);
- return this;
- }
+ /**
+ * KeyFinder sets the keys for a specific event.
+ *
+ * @return KeyFinder
+ */
+ private KeyFinder<S4Event> getKeyFinder() {
+ KeyFinder<S4Event> keyFinder = new KeyFinder<S4Event>() {
+ @Override
+ public List<String> get(S4Event s4event) {
+ List<String> results = new ArrayList<String>();
+ results.add(s4event.getKey());
+ return results;
+ }
+ };
-
- @Override
- public ProcessingItem connectInputKeyStream(Stream inputStream) {
+ return keyFinder;
+ }
- S4Stream stream = (S4Stream) inputStream;
- stream.setParallelism(this.paralellismLevel);
- stream.addStream(inputStream.getStreamId(),
- getKeyFinder(), this,S4Stream.GROUP_BY_KEY);
+ @Override
+ public ProcessingItem connectInputAllStream(Stream inputStream) {
- return this;
- }
-
- @Override
- public ProcessingItem connectInputShuffleStream(Stream inputStream) {
- S4Stream stream = (S4Stream) inputStream;
- stream.setParallelism(this.paralellismLevel);
- stream.addStream(inputStream.getStreamId(),
- getKeyFinder(), this,S4Stream.SHUFFLE);
+ S4Stream stream = (S4Stream) inputStream;
+ stream.setParallelism(this.paralellismLevel);
+ stream.addStream(inputStream.getStreamId(),
+ getKeyFinder(), this, S4Stream.BROADCAST);
+ return this;
+ }
- return this;
- }
+ @Override
+ public ProcessingItem connectInputKeyStream(Stream inputStream) {
- // Methods from ProcessingElement
- @Override
- protected void onCreate() {
- logger.debug("PE ID {}", getId());
- if (this.processor != null) {
- this.processor = this.processor.newProcessor(this.processor);
- this.processor.onCreate(Integer.parseInt(getId()));
- }
- }
+ S4Stream stream = (S4Stream) inputStream;
+ stream.setParallelism(this.paralellismLevel);
+ stream.addStream(inputStream.getStreamId(),
+ getKeyFinder(), this, S4Stream.GROUP_BY_KEY);
- @Override
- protected void onRemove() {
- // do nothing
- }
+ return this;
+ }
- @Override
- public int getParallelism() {
- return this.paralellismLevel;
- }
+ @Override
+ public ProcessingItem connectInputShuffleStream(Stream inputStream) {
+ S4Stream stream = (S4Stream) inputStream;
+ stream.setParallelism(this.paralellismLevel);
+ stream.addStream(inputStream.getStreamId(),
+ getKeyFinder(), this, S4Stream.SHUFFLE);
+
+ return this;
+ }
+
+ // Methods from ProcessingElement
+ @Override
+ protected void onCreate() {
+ logger.debug("PE ID {}", getId());
+ if (this.processor != null) {
+ this.processor = this.processor.newProcessor(this.processor);
+ this.processor.onCreate(Integer.parseInt(getId()));
+ }
+ }
+
+ @Override
+ protected void onRemove() {
+ // do nothing
+ }
+
+ @Override
+ public int getParallelism() {
+ return this.paralellismLevel;
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Stream.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Stream.java
index 78a3266..67a1385 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Stream.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Stream.java
@@ -35,151 +35,151 @@
* S4 Platform specific stream.
*
* @author severien
- *
+ *
*/
public class S4Stream extends AbstractStream {
- public static final int SHUFFLE = 0;
- public static final int GROUP_BY_KEY = 1;
- public static final int BROADCAST = 2;
+ public static final int SHUFFLE = 0;
+ public static final int GROUP_BY_KEY = 1;
+ public static final int BROADCAST = 2;
- private static final Logger logger = LoggerFactory.getLogger(S4Stream.class);
+ private static final Logger logger = LoggerFactory.getLogger(S4Stream.class);
- private S4DoTask app;
- private int processingItemParalellism;
- private int shuffleCounter;
+ private S4DoTask app;
+ private int processingItemParalellism;
+ private int shuffleCounter;
- private static final String NAME = "STREAM-";
- private static int OBJ_COUNTER = 0;
-
- /* The stream list */
- public List<StreamType> streams;
+ private static final String NAME = "STREAM-";
+ private static int OBJ_COUNTER = 0;
- public S4Stream(S4DoTask app) {
- super();
- this.app = app;
- this.processingItemParalellism = 1;
- this.shuffleCounter = 0;
- this.streams = new ArrayList<StreamType>();
- this.setStreamId(NAME+OBJ_COUNTER);
- OBJ_COUNTER++;
- }
-
- public S4Stream(S4DoTask app, S4ProcessingItem pi) {
- super();
- this.app = app;
- this.processingItemParalellism = 1;
- this.shuffleCounter = 0;
- this.streams = new ArrayList<StreamType>();
- this.setStreamId(NAME+OBJ_COUNTER);
- OBJ_COUNTER++;
-
- }
+ /* The stream list */
+ public List<StreamType> streams;
- /**
- *
- * @return
- */
- public int getParallelism() {
- return processingItemParalellism;
- }
+ public S4Stream(S4DoTask app) {
+ super();
+ this.app = app;
+ this.processingItemParalellism = 1;
+ this.shuffleCounter = 0;
+ this.streams = new ArrayList<StreamType>();
+ this.setStreamId(NAME + OBJ_COUNTER);
+ OBJ_COUNTER++;
+ }
- public void setParallelism(int parallelism) {
- this.processingItemParalellism = parallelism;
- }
+ public S4Stream(S4DoTask app, S4ProcessingItem pi) {
+ super();
+ this.app = app;
+ this.processingItemParalellism = 1;
+ this.shuffleCounter = 0;
+ this.streams = new ArrayList<StreamType>();
+ this.setStreamId(NAME + OBJ_COUNTER);
+ OBJ_COUNTER++;
- public void addStream(String streamID, KeyFinder<S4Event> finder,
- S4ProcessingItem pi, int type) {
- String streamName = streamID +"_"+pi.getName();
- org.apache.s4.core.Stream<S4Event> stream = this.app.createStream(
- streamName, pi);
- stream.setName(streamName);
- logger.debug("Stream name S4Stream {}", streamName);
- if (finder != null)
- stream.setKey(finder);
- this.streams.add(new StreamType(stream, type));
+ }
- }
+ /**
+ *
+ * @return
+ */
+ public int getParallelism() {
+ return processingItemParalellism;
+ }
- @Override
- public void put(ContentEvent event) {
+ public void setParallelism(int parallelism) {
+ this.processingItemParalellism = parallelism;
+ }
- for (int i = 0; i < streams.size(); i++) {
+ public void addStream(String streamID, KeyFinder<S4Event> finder,
+ S4ProcessingItem pi, int type) {
+ String streamName = streamID + "_" + pi.getName();
+ org.apache.s4.core.Stream<S4Event> stream = this.app.createStream(
+ streamName, pi);
+ stream.setName(streamName);
+ logger.debug("Stream name S4Stream {}", streamName);
+ if (finder != null)
+ stream.setKey(finder);
+ this.streams.add(new StreamType(stream, type));
- switch (streams.get(i).getType()) {
- case SHUFFLE:
- S4Event s4event = new S4Event(event);
- s4event.setStreamId(streams.get(i).getStream().getName());
- if(getParallelism() == 1) {
- s4event.setKey("0");
- }else {
- s4event.setKey(Integer.toString(shuffleCounter));
- }
- streams.get(i).getStream().put(s4event);
- shuffleCounter++;
- if (shuffleCounter >= (getParallelism())) {
- shuffleCounter = 0;
- }
-
- break;
+ }
- case GROUP_BY_KEY:
- S4Event s4event1 = new S4Event(event);
- s4event1.setStreamId(streams.get(i).getStream().getName());
- HashCodeBuilder hb = new HashCodeBuilder();
- hb.append(event.getKey());
- String key = Integer.toString(hb.build() % getParallelism());
- s4event1.setKey(key);
- streams.get(i).getStream().put(s4event1);
- break;
-
- case BROADCAST:
- for (int p = 0; p < this.getParallelism(); p++) {
- S4Event s4event2 = new S4Event(event);
- s4event2.setStreamId(streams.get(i).getStream().getName());
- s4event2.setKey(Integer.toString(p));
- streams.get(i).getStream().put(s4event2);
- }
- break;
+ @Override
+ public void put(ContentEvent event) {
- default:
- break;
- }
+ for (int i = 0; i < streams.size(); i++) {
-
- }
+ switch (streams.get(i).getType()) {
+ case SHUFFLE:
+ S4Event s4event = new S4Event(event);
+ s4event.setStreamId(streams.get(i).getStream().getName());
+ if (getParallelism() == 1) {
+ s4event.setKey("0");
+ } else {
+ s4event.setKey(Integer.toString(shuffleCounter));
+ }
+ streams.get(i).getStream().put(s4event);
+ shuffleCounter++;
+ if (shuffleCounter >= (getParallelism())) {
+ shuffleCounter = 0;
+ }
- }
+ break;
- /**
- * Subclass for definig stream connection type
- * @author severien
- *
- */
- class StreamType {
- org.apache.s4.core.Stream<S4Event> stream;
- int type;
+ case GROUP_BY_KEY:
+ S4Event s4event1 = new S4Event(event);
+ s4event1.setStreamId(streams.get(i).getStream().getName());
+ HashCodeBuilder hb = new HashCodeBuilder();
+ hb.append(event.getKey());
+ String key = Integer.toString(hb.build() % getParallelism());
+ s4event1.setKey(key);
+ streams.get(i).getStream().put(s4event1);
+ break;
- public StreamType(org.apache.s4.core.Stream<S4Event> s, int t) {
- this.stream = s;
- this.type = t;
- }
+ case BROADCAST:
+ for (int p = 0; p < this.getParallelism(); p++) {
+ S4Event s4event2 = new S4Event(event);
+ s4event2.setStreamId(streams.get(i).getStream().getName());
+ s4event2.setKey(Integer.toString(p));
+ streams.get(i).getStream().put(s4event2);
+ }
+ break;
- public org.apache.s4.core.Stream<S4Event> getStream() {
- return stream;
- }
+ default:
+ break;
+ }
- public void setStream(org.apache.s4.core.Stream<S4Event> stream) {
- this.stream = stream;
- }
+ }
- public int getType() {
- return type;
- }
+ }
- public void setType(int type) {
- this.type = type;
- }
+ /**
+ * Subclass for defining stream connection type
+ *
+ * @author severien
+ *
+ */
+ class StreamType {
+ org.apache.s4.core.Stream<S4Event> stream;
+ int type;
- }
+ public StreamType(org.apache.s4.core.Stream<S4Event> s, int t) {
+ this.stream = s;
+ this.type = t;
+ }
+
+ public org.apache.s4.core.Stream<S4Event> getStream() {
+ return stream;
+ }
+
+ public void setStream(org.apache.s4.core.Stream<S4Event> stream) {
+ this.stream = stream;
+ }
+
+ public int getType() {
+ return type;
+ }
+
+ public void setType(int type) {
+ this.type = type;
+ }
+
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Submitter.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Submitter.java
index c7ef92c..cf5a9b3 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Submitter.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Submitter.java
@@ -44,103 +44,102 @@
public class S4Submitter implements ISubmitter {
- private static Logger logger = LoggerFactory.getLogger(S4Submitter.class);
+ private static Logger logger = LoggerFactory.getLogger(S4Submitter.class);
- @Override
- public void deployTask(Task task) {
- // TODO: Get application FROM HTTP server
- // TODO: Initializa a http server to serve the app package
-
- String appURIString = null;
-// File app = new File(System.getProperty("user.dir")
-// + "/src/site/dist/SAMOA-S4-0.1-dist.jar");
-
- // TODO: String app url http://localhost:8000/SAMOA-S4-0.1-dist.jar
- try {
- URL appURL = new URL("http://localhost:8000/SAMOA-S4-0.1.jar");
- appURIString = appURL.toString();
- } catch (MalformedURLException e1) {
- e1.printStackTrace();
- }
-
-// try {
-// appURIString = app.toURI().toURL().toString();
-// } catch (MalformedURLException e) {
-// e.printStackTrace();
-// }
- if (task == null) {
- logger.error("Can't execute since evaluation task is not set!");
- return;
- } else {
- logger.info("Deploying SAMOA S4 task [{}] from location [{}]. ",
- task.getClass().getSimpleName(), appURIString);
- }
+ @Override
+ public void deployTask(Task task) {
+ // TODO: Get application FROM HTTP server
+ // TODO: Initialize an HTTP server to serve the app package
- String[] args = { "-c=testCluster2",
- "-appClass=" + S4DoTask.class.getName(),
- "-appName=" + "samoaApp",
- "-p=evalTask=" + task.getClass().getSimpleName(),
- "-zk=localhost:2181", "-s4r=" + appURIString , "-emc=" + SamoaSerializerModule.class.getName()};
- // "-emc=" + S4MOAModule.class.getName(),
- // "@" +
- // Resources.getResource("s4moa.properties").getFile(),
+ String appURIString = null;
+ // File app = new File(System.getProperty("user.dir")
+ // + "/src/site/dist/SAMOA-S4-0.1-dist.jar");
- S4Config s4config = new S4Config();
- JCommander jc = new JCommander(s4config);
- jc.parse(args);
+ // TODO: String app url http://localhost:8000/SAMOA-S4-0.1-dist.jar
+ try {
+ URL appURL = new URL("http://localhost:8000/SAMOA-S4-0.1.jar");
+ appURIString = appURL.toString();
+ } catch (MalformedURLException e1) {
+ e1.printStackTrace();
+ }
- Map<String, String> namedParameters = new HashMap<String, String>();
- for (String parameter : s4config.namedParameters) {
- String[] param = parameter.split("=");
- namedParameters.put(param[0], param[1]);
- }
+ // try {
+ // appURIString = app.toURI().toURL().toString();
+ // } catch (MalformedURLException e) {
+ // e.printStackTrace();
+ // }
+ if (task == null) {
+ logger.error("Can't execute since evaluation task is not set!");
+ return;
+ } else {
+ logger.info("Deploying SAMOA S4 task [{}] from location [{}]. ",
+ task.getClass().getSimpleName(), appURIString);
+ }
- AppConfig config = new AppConfig.Builder()
- .appClassName(s4config.appClass).appName(s4config.appName)
- .appURI(s4config.appURI).namedParameters(namedParameters)
- .build();
+ String[] args = { "-c=testCluster2",
+ "-appClass=" + S4DoTask.class.getName(),
+ "-appName=" + "samoaApp",
+ "-p=evalTask=" + task.getClass().getSimpleName(),
+ "-zk=localhost:2181", "-s4r=" + appURIString, "-emc=" + SamoaSerializerModule.class.getName() };
+ // "-emc=" + S4MOAModule.class.getName(),
+ // "@" +
+ // Resources.getResource("s4moa.properties").getFile(),
- DeploymentUtils.initAppConfig(config, s4config.clusterName, true,
- s4config.zkString);
+ S4Config s4config = new S4Config();
+ JCommander jc = new JCommander(s4config);
+ jc.parse(args);
- System.out.println("Suposedly deployed on S4");
- }
+ Map<String, String> namedParameters = new HashMap<String, String>();
+ for (String parameter : s4config.namedParameters) {
+ String[] param = parameter.split("=");
+ namedParameters.put(param[0], param[1]);
+ }
-
- public void initHTTPServer() {
-
- }
-
- @Parameters(separators = "=")
- public static class S4Config {
+ AppConfig config = new AppConfig.Builder()
+ .appClassName(s4config.appClass).appName(s4config.appName)
+ .appURI(s4config.appURI).namedParameters(namedParameters)
+ .build();
- @Parameter(names = { "-c", "-cluster" }, description = "Cluster name", required = true)
- String clusterName = null;
+ DeploymentUtils.initAppConfig(config, s4config.clusterName, true,
+ s4config.zkString);
- @Parameter(names = "-appClass", description = "Main App class", required = false)
- String appClass = null;
+ System.out.println("Suposedly deployed on S4");
+ }
- @Parameter(names = "-appName", description = "Application name", required = false)
- String appName = null;
+ public void initHTTPServer() {
- @Parameter(names = "-s4r", description = "Application URI", required = false)
- String appURI = null;
+ }
- @Parameter(names = "-zk", description = "ZooKeeper connection string", required = false)
- String zkString = null;
+ @Parameters(separators = "=")
+ public static class S4Config {
- @Parameter(names = { "-extraModulesClasses", "-emc" }, description = "Comma-separated list of additional configuration modules (they will be instantiated through their constructor without arguments).", required = false)
- List<String> extraModules = new ArrayList<String>();
+ @Parameter(names = { "-c", "-cluster" }, description = "Cluster name", required = true)
+ String clusterName = null;
- @Parameter(names = { "-p", "-namedStringParameters" }, description = "Comma-separated list of inline configuration "
- + "parameters, taking precedence over homonymous configuration parameters from configuration files. "
- + "Syntax: '-p=name1=value1,name2=value2 '", required = false, converter = ParsingUtils.InlineConfigParameterConverter.class)
- List<String> namedParameters = new ArrayList<String>();
+ @Parameter(names = "-appClass", description = "Main App class", required = false)
+ String appClass = null;
- }
+ @Parameter(names = "-appName", description = "Application name", required = false)
+ String appName = null;
- @Override
- public void setLocal(boolean bool) {
- // TODO S4 works the same for local and distributed environments
- }
+ @Parameter(names = "-s4r", description = "Application URI", required = false)
+ String appURI = null;
+
+ @Parameter(names = "-zk", description = "ZooKeeper connection string", required = false)
+ String zkString = null;
+
+ @Parameter(names = { "-extraModulesClasses", "-emc" }, description = "Comma-separated list of additional configuration modules (they will be instantiated through their constructor without arguments).", required = false)
+ List<String> extraModules = new ArrayList<String>();
+
+ @Parameter(names = { "-p", "-namedStringParameters" }, description = "Comma-separated list of inline configuration "
+ + "parameters, taking precedence over homonymous configuration parameters from configuration files. "
+ + "Syntax: '-p=name1=value1,name2=value2 '", required = false, converter = ParsingUtils.InlineConfigParameterConverter.class)
+ List<String> namedParameters = new ArrayList<String>();
+
+ }
+
+ @Override
+ public void setLocal(boolean bool) {
+ // TODO S4 works the same for local and distributed environments
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Topology.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Topology.java
index 6bef0e8..2f7661d 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Topology.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/S4Topology.java
@@ -24,38 +24,40 @@
import com.yahoo.labs.samoa.topology.AbstractTopology;
public class S4Topology extends AbstractTopology {
-
- // CASEY: it seems evaluationTask is not used.
- // Remove it for now
-
-// private String _evaluationTask;
-// S4Topology(String topoName, String evalTask) {
-// super(topoName);
-// }
-//
-// S4Topology(String topoName) {
-// this(topoName, null);
-// }
+ // CASEY: it seems evaluationTask is not used.
+ // Remove it for now
-// @Override
-// public void setEvaluationTask(String evalTask) {
-// _evaluationTask = evalTask;
-// }
-//
-// @Override
-// public String getEvaluationTask() {
-// return _evaluationTask;
-// }
-
- S4Topology(String topoName) {
- super(topoName);
- }
-
- protected EntranceProcessingItem getEntranceProcessingItem() {
- if (this.getEntranceProcessingItems() == null) return null;
- if (this.getEntranceProcessingItems().size() < 1) return null;
- // TODO: support multiple entrance PIs
- return (EntranceProcessingItem)this.getEntranceProcessingItems().toArray()[0];
- }
+ // private String _evaluationTask;
+
+ // S4Topology(String topoName, String evalTask) {
+ // super(topoName);
+ // }
+ //
+ // S4Topology(String topoName) {
+ // this(topoName, null);
+ // }
+
+ // @Override
+ // public void setEvaluationTask(String evalTask) {
+ // _evaluationTask = evalTask;
+ // }
+ //
+ // @Override
+ // public String getEvaluationTask() {
+ // return _evaluationTask;
+ // }
+
+ S4Topology(String topoName) {
+ super(topoName);
+ }
+
+ protected EntranceProcessingItem getEntranceProcessingItem() {
+ if (this.getEntranceProcessingItems() == null)
+ return null;
+ if (this.getEntranceProcessingItems().size() < 1)
+ return null;
+ // TODO: support multiple entrance PIs
+ return (EntranceProcessingItem) this.getEntranceProcessingItems().toArray()[0];
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializer.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializer.java
index 4ae2296..61648e6 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializer.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializer.java
@@ -32,68 +32,69 @@
import com.yahoo.labs.samoa.learners.classifiers.trees.AttributeContentEvent;
import com.yahoo.labs.samoa.learners.classifiers.trees.ComputeContentEvent;
-public class SamoaSerializer implements SerializerDeserializer{
+public class SamoaSerializer implements SerializerDeserializer {
- private ThreadLocal<Kryo> kryoThreadLocal;
- private ThreadLocal<Output> outputThreadLocal;
+ private ThreadLocal<Kryo> kryoThreadLocal;
+ private ThreadLocal<Output> outputThreadLocal;
- private int initialBufferSize = 2048;
- private int maxBufferSize = 256 * 1024;
+ private int initialBufferSize = 2048;
+ private int maxBufferSize = 256 * 1024;
- public void setMaxBufferSize(int maxBufferSize) {
- this.maxBufferSize = maxBufferSize;
+ public void setMaxBufferSize(int maxBufferSize) {
+ this.maxBufferSize = maxBufferSize;
+ }
+
+ /**
+ *
+ * @param classLoader
+ * classloader able to handle classes to serialize/deserialize. For
+ * instance, application-level events can only be handled by the
+ * application classloader.
+ */
+ @Inject
+ public SamoaSerializer(@Assisted final ClassLoader classLoader) {
+ kryoThreadLocal = new ThreadLocal<Kryo>() {
+
+ @Override
+ protected Kryo initialValue() {
+ Kryo kryo = new Kryo();
+ kryo.setClassLoader(classLoader);
+ kryo.register(AttributeContentEvent.class, new AttributeContentEvent.AttributeCEFullPrecSerializer());
+ kryo.register(ComputeContentEvent.class, new ComputeContentEvent.ComputeCEFullPrecSerializer());
+ kryo.setRegistrationRequired(false);
+ return kryo;
+ }
+ };
+
+ outputThreadLocal = new ThreadLocal<Output>() {
+ @Override
+ protected Output initialValue() {
+ Output output = new Output(initialBufferSize, maxBufferSize);
+ return output;
+ }
+ };
+
+ }
+
+ @Override
+ public Object deserialize(ByteBuffer rawMessage) {
+ Input input = new Input(rawMessage.array());
+ try {
+ return kryoThreadLocal.get().readClassAndObject(input);
+ } finally {
+ input.close();
}
+ }
- /**
- *
- * @param classLoader
- * classloader able to handle classes to serialize/deserialize. For instance, application-level events
- * can only be handled by the application classloader.
- */
- @Inject
- public SamoaSerializer(@Assisted final ClassLoader classLoader) {
- kryoThreadLocal = new ThreadLocal<Kryo>() {
-
- @Override
- protected Kryo initialValue() {
- Kryo kryo = new Kryo();
- kryo.setClassLoader(classLoader);
- kryo.register(AttributeContentEvent.class, new AttributeContentEvent.AttributeCEFullPrecSerializer());
- kryo.register(ComputeContentEvent.class, new ComputeContentEvent.ComputeCEFullPrecSerializer());
- kryo.setRegistrationRequired(false);
- return kryo;
- }
- };
-
- outputThreadLocal = new ThreadLocal<Output>() {
- @Override
- protected Output initialValue() {
- Output output = new Output(initialBufferSize, maxBufferSize);
- return output;
- }
- };
-
+ @SuppressWarnings("resource")
+ @Override
+ public ByteBuffer serialize(Object message) {
+ Output output = outputThreadLocal.get();
+ try {
+ kryoThreadLocal.get().writeClassAndObject(output, message);
+ return ByteBuffer.wrap(output.toBytes());
+ } finally {
+ output.clear();
}
-
- @Override
- public Object deserialize(ByteBuffer rawMessage) {
- Input input = new Input(rawMessage.array());
- try {
- return kryoThreadLocal.get().readClassAndObject(input);
- } finally {
- input.close();
- }
- }
-
- @SuppressWarnings("resource")
- @Override
- public ByteBuffer serialize(Object message) {
- Output output = outputThreadLocal.get();
- try {
- kryoThreadLocal.get().writeClassAndObject(output, message);
- return ByteBuffer.wrap(output.toBytes());
- } finally {
- output.clear();
- }
- }
+ }
}
diff --git a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializerModule.java b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializerModule.java
index 311e449..a367eb5 100644
--- a/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializerModule.java
+++ b/samoa-s4/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSerializerModule.java
@@ -26,10 +26,10 @@
public class SamoaSerializerModule extends AbstractModule {
- @Override
- protected void configure() {
- bind(SerializerDeserializer.class).to(SamoaSerializer.class);
-
- }
+ @Override
+ protected void configure() {
+ bind(SerializerDeserializer.class).to(SamoaSerializer.class);
+
+ }
}
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/SamzaDoTask.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/SamzaDoTask.java
index 45dd901..6c6103c 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/SamzaDoTask.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/SamzaDoTask.java
@@ -44,184 +44,183 @@
*/
public class SamzaDoTask {
- private static final Logger logger = LoggerFactory.getLogger(SamzaDoTask.class);
-
- private static final String LOCAL_MODE = "local";
- private static final String YARN_MODE = "yarn";
-
- // FLAGS
- private static final String YARN_CONF_FLAG = "--yarn_home";
- private static final String MODE_FLAG = "--mode";
- private static final String ZK_FLAG = "--zookeeper";
- private static final String KAFKA_FLAG = "--kafka";
- private static final String KAFKA_REPLICATION_FLAG = "--kafka_replication_factor";
- private static final String CHECKPOINT_FREQ_FLAG = "--checkpoint_frequency";
- private static final String JAR_PACKAGE_FLAG = "--jar_package";
- private static final String SAMOA_HDFS_DIR_FLAG = "--samoa_hdfs_dir";
- private static final String AM_MEMORY_FLAG = "--yarn_am_mem";
- private static final String CONTAINER_MEMORY_FLAG = "--yarn_container_mem";
- private static final String PI_PER_CONTAINER_FLAG = "--pi_per_container";
-
- private static final String KRYO_REGISTER_FLAG = "--kryo_register";
-
- // config values
- private static int kafkaReplicationFactor = 1;
- private static int checkpointFrequency = 60000;
- private static String kafka = "localhost:9092";
- private static String zookeeper = "localhost:2181";
- private static boolean isLocal = true;
- private static String yarnConfHome = null;
- private static String samoaHDFSDir = null;
- private static String jarPackagePath = null;
- private static int amMem = 1024;
- private static int containerMem = 1024;
- private static int piPerContainer = 2;
- private static String kryoRegisterFile = null;
-
- /*
- * 1. Read arguments
- * 2. Construct topology/task
- * 3. Upload the JAR to HDFS if we are running on YARN
- * 4. Submit topology to SamzaEngine
- */
- public static void main(String[] args) {
- // Read arguments
- List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
- parseArguments(tmpArgs);
-
- args = tmpArgs.toArray(new String[0]);
-
- // Init Task
- StringBuilder cliString = new StringBuilder();
- for (int i = 0; i < args.length; i++) {
- cliString.append(" ").append(args[i]);
+ private static final Logger logger = LoggerFactory.getLogger(SamzaDoTask.class);
+
+ private static final String LOCAL_MODE = "local";
+ private static final String YARN_MODE = "yarn";
+
+ // FLAGS
+ private static final String YARN_CONF_FLAG = "--yarn_home";
+ private static final String MODE_FLAG = "--mode";
+ private static final String ZK_FLAG = "--zookeeper";
+ private static final String KAFKA_FLAG = "--kafka";
+ private static final String KAFKA_REPLICATION_FLAG = "--kafka_replication_factor";
+ private static final String CHECKPOINT_FREQ_FLAG = "--checkpoint_frequency";
+ private static final String JAR_PACKAGE_FLAG = "--jar_package";
+ private static final String SAMOA_HDFS_DIR_FLAG = "--samoa_hdfs_dir";
+ private static final String AM_MEMORY_FLAG = "--yarn_am_mem";
+ private static final String CONTAINER_MEMORY_FLAG = "--yarn_container_mem";
+ private static final String PI_PER_CONTAINER_FLAG = "--pi_per_container";
+
+ private static final String KRYO_REGISTER_FLAG = "--kryo_register";
+
+ // config values
+ private static int kafkaReplicationFactor = 1;
+ private static int checkpointFrequency = 60000;
+ private static String kafka = "localhost:9092";
+ private static String zookeeper = "localhost:2181";
+ private static boolean isLocal = true;
+ private static String yarnConfHome = null;
+ private static String samoaHDFSDir = null;
+ private static String jarPackagePath = null;
+ private static int amMem = 1024;
+ private static int containerMem = 1024;
+ private static int piPerContainer = 2;
+ private static String kryoRegisterFile = null;
+
+ /*
+ * 1. Read arguments 2. Construct topology/task 3. Upload the JAR to HDFS if
+ * we are running on YARN 4. Submit topology to SamzaEngine
+ */
+ public static void main(String[] args) {
+ // Read arguments
+ List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
+ parseArguments(tmpArgs);
+
+ args = tmpArgs.toArray(new String[0]);
+
+ // Init Task
+ StringBuilder cliString = new StringBuilder();
+ for (int i = 0; i < args.length; i++) {
+ cliString.append(" ").append(args[i]);
+ }
+ logger.debug("Command line string = {}", cliString.toString());
+ System.out.println("Command line string = " + cliString.toString());
+
+ Task task = null;
+ try {
+ task = (Task) ClassOption.cliStringToObject(cliString.toString(), Task.class, null);
+ logger.info("Sucessfully instantiating {}", task.getClass().getCanonicalName());
+ } catch (Exception e) {
+ logger.error("Fail to initialize the task", e);
+ System.out.println("Fail to initialize the task" + e);
+ return;
+ }
+ task.setFactory(new SamzaComponentFactory());
+ task.init();
+
+ // Upload JAR file to HDFS
+ String hdfsPath = null;
+ if (!isLocal) {
+ Path path = FileSystems.getDefault().getPath(jarPackagePath);
+ hdfsPath = uploadJarToHDFS(path.toFile());
+ if (hdfsPath == null) {
+ System.out.println("Fail uploading JAR file \"" + path.toAbsolutePath().toString() + "\" to HDFS.");
+ return;
+ }
+ }
+
+ // Set parameters
+ SamzaEngine.getEngine()
+ .setLocalMode(isLocal)
+ .setZooKeeper(zookeeper)
+ .setKafka(kafka)
+ .setYarnPackage(hdfsPath)
+ .setKafkaReplicationFactor(kafkaReplicationFactor)
+ .setConfigHome(yarnConfHome)
+ .setAMMemory(amMem)
+ .setContainerMemory(containerMem)
+ .setPiPerContainerRatio(piPerContainer)
+ .setKryoRegisterFile(kryoRegisterFile)
+ .setCheckpointFrequency(checkpointFrequency);
+
+ // Submit topology
+ SamzaEngine.submitTopology((SamzaTopology) task.getTopology());
+
+ }
+
+ private static boolean isLocalMode(String mode) {
+ return mode.equals(LOCAL_MODE);
+ }
+
+ private static void parseArguments(List<String> args) {
+ for (int i = args.size() - 1; i >= 0; i--) {
+ String arg = args.get(i).trim();
+ String[] splitted = arg.split("=", 2);
+
+ if (splitted.length >= 2) {
+ // YARN config folder which contains conf/core-site.xml,
+ // conf/hdfs-site.xml, conf/yarn-site.xml
+ if (splitted[0].equals(YARN_CONF_FLAG)) {
+ yarnConfHome = splitted[1];
+ args.remove(i);
}
- logger.debug("Command line string = {}", cliString.toString());
- System.out.println("Command line string = " + cliString.toString());
-
- Task task = null;
- try {
- task = (Task) ClassOption.cliStringToObject(cliString.toString(), Task.class, null);
- logger.info("Sucessfully instantiating {}", task.getClass().getCanonicalName());
- } catch (Exception e) {
- logger.error("Fail to initialize the task", e);
- System.out.println("Fail to initialize the task" + e);
- return;
+ // host:port for zookeeper cluster
+ else if (splitted[0].equals(ZK_FLAG)) {
+ zookeeper = splitted[1];
+ args.remove(i);
}
- task.setFactory(new SamzaComponentFactory());
- task.init();
-
- // Upload JAR file to HDFS
- String hdfsPath = null;
- if (!isLocal) {
- Path path = FileSystems.getDefault().getPath(jarPackagePath);
- hdfsPath = uploadJarToHDFS(path.toFile());
- if(hdfsPath == null) {
- System.out.println("Fail uploading JAR file \""+path.toAbsolutePath().toString()+"\" to HDFS.");
- return;
- }
- }
-
- // Set parameters
- SamzaEngine.getEngine()
- .setLocalMode(isLocal)
- .setZooKeeper(zookeeper)
- .setKafka(kafka)
- .setYarnPackage(hdfsPath)
- .setKafkaReplicationFactor(kafkaReplicationFactor)
- .setConfigHome(yarnConfHome)
- .setAMMemory(amMem)
- .setContainerMemory(containerMem)
- .setPiPerContainerRatio(piPerContainer)
- .setKryoRegisterFile(kryoRegisterFile)
- .setCheckpointFrequency(checkpointFrequency);
-
- // Submit topology
- SamzaEngine.submitTopology((SamzaTopology)task.getTopology());
-
- }
-
- private static boolean isLocalMode(String mode) {
- return mode.equals(LOCAL_MODE);
- }
-
- private static void parseArguments(List<String> args) {
- for (int i=args.size()-1; i>=0; i--) {
- String arg = args.get(i).trim();
- String[] splitted = arg.split("=",2);
-
- if (splitted.length >= 2) {
- // YARN config folder which contains conf/core-site.xml,
- // conf/hdfs-site.xml, conf/yarn-site.xml
- if (splitted[0].equals(YARN_CONF_FLAG)) {
- yarnConfHome = splitted[1];
- args.remove(i);
- }
- // host:port for zookeeper cluster
- else if (splitted[0].equals(ZK_FLAG)) {
- zookeeper = splitted[1];
- args.remove(i);
- }
- // host:port,... for kafka broker(s)
- else if (splitted[0].equals(KAFKA_FLAG)) {
- kafka = splitted[1];
- args.remove(i);
- }
- // whether we are running Samza in Local mode or YARN mode
- else if (splitted[0].equals(MODE_FLAG)) {
- isLocal = isLocalMode(splitted[1]);
- args.remove(i);
- }
- // memory requirement for YARN application master
- else if (splitted[0].equals(AM_MEMORY_FLAG)) {
- amMem = Integer.parseInt(splitted[1]);
- args.remove(i);
- }
- // memory requirement for YARN worker container
- else if (splitted[0].equals(CONTAINER_MEMORY_FLAG)) {
- containerMem = Integer.parseInt(splitted[1]);
- args.remove(i);
- }
- // the path to JAR archive that we need to upload to HDFS
- else if (splitted[0].equals(JAR_PACKAGE_FLAG)) {
- jarPackagePath = splitted[1];
- args.remove(i);
- }
- // the HDFS dir for SAMOA files
- else if (splitted[0].equals(SAMOA_HDFS_DIR_FLAG)) {
- samoaHDFSDir = splitted[1];
- if (samoaHDFSDir.length() < 1) samoaHDFSDir = null;
- args.remove(i);
- }
- // number of max PI instances per container
- // this will be used to compute the number of containers
- // AM will request for the job
- else if (splitted[0].equals(PI_PER_CONTAINER_FLAG)) {
- piPerContainer = Integer.parseInt(splitted[1]);
- args.remove(i);
- }
- // kafka streams replication factor
- else if (splitted[0].equals(KAFKA_REPLICATION_FLAG)) {
- kafkaReplicationFactor = Integer.parseInt(splitted[1]);
- args.remove(i);
- }
- // checkpoint frequency in ms
- else if (splitted[0].equals(CHECKPOINT_FREQ_FLAG)) {
- checkpointFrequency = Integer.parseInt(splitted[1]);
- args.remove(i);
- }
- // the file contains registration information for Kryo serializer
- else if (splitted[0].equals(KRYO_REGISTER_FLAG)) {
- kryoRegisterFile = splitted[1];
- args.remove(i);
- }
- }
- }
- }
-
- private static String uploadJarToHDFS(File file) {
- SystemsUtils.setHadoopConfigHome(yarnConfHome);
- SystemsUtils.setSAMOADir(samoaHDFSDir);
- return SystemsUtils.copyToHDFS(file, file.getName());
- }
+ // host:port,... for kafka broker(s)
+ else if (splitted[0].equals(KAFKA_FLAG)) {
+ kafka = splitted[1];
+ args.remove(i);
+ }
+ // whether we are running Samza in Local mode or YARN mode
+ else if (splitted[0].equals(MODE_FLAG)) {
+ isLocal = isLocalMode(splitted[1]);
+ args.remove(i);
+ }
+ // memory requirement for YARN application master
+ else if (splitted[0].equals(AM_MEMORY_FLAG)) {
+ amMem = Integer.parseInt(splitted[1]);
+ args.remove(i);
+ }
+ // memory requirement for YARN worker container
+ else if (splitted[0].equals(CONTAINER_MEMORY_FLAG)) {
+ containerMem = Integer.parseInt(splitted[1]);
+ args.remove(i);
+ }
+ // the path to JAR archive that we need to upload to HDFS
+ else if (splitted[0].equals(JAR_PACKAGE_FLAG)) {
+ jarPackagePath = splitted[1];
+ args.remove(i);
+ }
+ // the HDFS dir for SAMOA files
+ else if (splitted[0].equals(SAMOA_HDFS_DIR_FLAG)) {
+ samoaHDFSDir = splitted[1];
+ if (samoaHDFSDir.length() < 1)
+ samoaHDFSDir = null;
+ args.remove(i);
+ }
+ // number of max PI instances per container
+ // this will be used to compute the number of containers
+ // AM will request for the job
+ else if (splitted[0].equals(PI_PER_CONTAINER_FLAG)) {
+ piPerContainer = Integer.parseInt(splitted[1]);
+ args.remove(i);
+ }
+ // kafka streams replication factor
+ else if (splitted[0].equals(KAFKA_REPLICATION_FLAG)) {
+ kafkaReplicationFactor = Integer.parseInt(splitted[1]);
+ args.remove(i);
+ }
+ // checkpoint frequency in ms
+ else if (splitted[0].equals(CHECKPOINT_FREQ_FLAG)) {
+ checkpointFrequency = Integer.parseInt(splitted[1]);
+ args.remove(i);
+ }
+ // the file contains registration information for Kryo serializer
+ else if (splitted[0].equals(KRYO_REGISTER_FLAG)) {
+ kryoRegisterFile = splitted[1];
+ args.remove(i);
+ }
+ }
+ }
+ }
+
+ private static String uploadJarToHDFS(File file) {
+ SystemsUtils.setHadoopConfigHome(yarnConfHome);
+ SystemsUtils.setSAMOADir(samoaHDFSDir);
+ return SystemsUtils.copyToHDFS(file, file.getName());
+ }
}
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSystemFactory.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSystemFactory.java
index 362e0a5..1a4b57f 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSystemFactory.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamoaSystemFactory.java
@@ -32,26 +32,25 @@
import com.yahoo.labs.samoa.topology.impl.SamzaEntranceProcessingItem.SamoaSystemConsumer;
/**
- * Implementation of Samza's SystemFactory
- * Samza will use this factory to get our custom consumer
- * which gets the events from SAMOA EntranceProcessor
- * and feed them to EntranceProcessingItem task
+ * Implementation of Samza's SystemFactory Samza will use this factory to get
+ * our custom consumer which gets the events from SAMOA EntranceProcessor and
+ * feeds them to the EntranceProcessingItem task
*
* @author Anh Thu Vu
*/
public class SamoaSystemFactory implements SystemFactory {
- @Override
- public SystemAdmin getAdmin(String systemName, Config config) {
- return new SinglePartitionWithoutOffsetsSystemAdmin();
- }
+ @Override
+ public SystemAdmin getAdmin(String systemName, Config config) {
+ return new SinglePartitionWithoutOffsetsSystemAdmin();
+ }
- @Override
- public SystemConsumer getConsumer(String systemName, Config config, MetricsRegistry registry) {
- return new SamoaSystemConsumer(systemName, config);
- }
+ @Override
+ public SystemConsumer getConsumer(String systemName, Config config, MetricsRegistry registry) {
+ return new SamoaSystemConsumer(systemName, config);
+ }
- @Override
- public SystemProducer getProducer(String systemName, Config config, MetricsRegistry registry) {
- throw new SamzaException("This implementation is not supposed to produce anything.");
- }
+ @Override
+ public SystemProducer getProducer(String systemName, Config config, MetricsRegistry registry) {
+ throw new SamzaException("This implementation is not supposed to produce anything.");
+ }
}
\ No newline at end of file
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaComponentFactory.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaComponentFactory.java
index d71d97b..278b1f2 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaComponentFactory.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaComponentFactory.java
@@ -35,28 +35,28 @@
* @author Anh Thu Vu
*/
public class SamzaComponentFactory implements ComponentFactory {
- @Override
- public ProcessingItem createPi(Processor processor) {
- return this.createPi(processor, 1);
- }
+ @Override
+ public ProcessingItem createPi(Processor processor) {
+ return this.createPi(processor, 1);
+ }
- @Override
- public ProcessingItem createPi(Processor processor, int parallelism) {
- return new SamzaProcessingItem(processor, parallelism);
- }
+ @Override
+ public ProcessingItem createPi(Processor processor, int parallelism) {
+ return new SamzaProcessingItem(processor, parallelism);
+ }
- @Override
- public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor) {
- return new SamzaEntranceProcessingItem(entranceProcessor);
- }
-
- @Override
- public Stream createStream(IProcessingItem sourcePi) {
- return new SamzaStream(sourcePi);
- }
-
- @Override
- public Topology createTopology(String topoName) {
- return new SamzaTopology(topoName);
- }
+ @Override
+ public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor) {
+ return new SamzaEntranceProcessingItem(entranceProcessor);
+ }
+
+ @Override
+ public Stream createStream(IProcessingItem sourcePi) {
+ return new SamzaStream(sourcePi);
+ }
+
+ @Override
+ public Topology createTopology(String topoName) {
+ return new SamzaTopology(topoName);
+ }
}
\ No newline at end of file
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEngine.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEngine.java
index 7339443..e3141f8 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEngine.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEngine.java
@@ -35,163 +35,162 @@
import com.yahoo.labs.samoa.utils.SystemsUtils;
/**
- * This class will submit a list of Samza jobs with
- * the Configs generated from the input topology
+ * This class will submit a list of Samza jobs with the Configs generated from
+ * the input topology
*
* @author Anh Thu Vu
- *
+ *
*/
public class SamzaEngine {
-
- private static final Logger logger = LoggerFactory.getLogger(SamzaEngine.class);
-
- /*
- * Singleton instance
- */
- private static SamzaEngine engine = new SamzaEngine();
-
- private String zookeeper;
- private String kafka;
- private int kafkaReplicationFactor;
- private boolean isLocalMode;
- private String yarnPackagePath;
- private String yarnConfHome;
-
- private String kryoRegisterFile;
-
- private int amMem;
- private int containerMem;
- private int piPerContainerRatio;
-
- private int checkpointFrequency;
-
- private void _submitTopology(SamzaTopology topology) {
-
- // Setup SamzaConfigFactory
- SamzaConfigFactory configFactory = new SamzaConfigFactory();
- configFactory.setLocalMode(isLocalMode)
- .setZookeeper(zookeeper)
- .setKafka(kafka)
- .setYarnPackage(yarnPackagePath)
- .setAMMemory(amMem)
- .setContainerMemory(containerMem)
- .setPiPerContainerRatio(piPerContainerRatio)
- .setKryoRegisterFile(kryoRegisterFile)
- .setCheckpointFrequency(checkpointFrequency)
- .setReplicationFactor(kafkaReplicationFactor);
-
- // Generate the list of Configs
- List<MapConfig> configs;
- try {
- // ConfigFactory generate a list of configs
- // Serialize a map of PIs and store in a file in the jar at jarFilePath
- // (in dat/ folder)
- configs = configFactory.getMapConfigsForTopology(topology);
- } catch (Exception e) {
- e.printStackTrace();
- return;
- }
-
- // Create kafka streams
- Set<Stream> streams = topology.getStreams();
- for (Stream stream:streams) {
- SamzaStream samzaStream = (SamzaStream) stream;
- List<SamzaSystemStream> systemStreams = samzaStream.getSystemStreams();
- for (SamzaSystemStream systemStream:systemStreams) {
- // all streams should be kafka streams
- SystemsUtils.createKafkaTopic(systemStream.getStream(),systemStream.getParallelism(),kafkaReplicationFactor);
- }
- }
-
- // Submit the jobs with those configs
- for (MapConfig config:configs) {
- logger.info("Config:{}",config);
- JobRunner jobRunner = new JobRunner(config);
- jobRunner.run();
- }
- }
- private void _setupSystemsUtils() {
- // Setup Utils
- if (!isLocalMode)
- SystemsUtils.setHadoopConfigHome(yarnConfHome);
- SystemsUtils.setZookeeper(zookeeper);
- }
-
- /*
- * Setter methods
- */
- public static SamzaEngine getEngine() {
- return engine;
- }
-
- public SamzaEngine setZooKeeper(String zk) {
- this.zookeeper = zk;
- return this;
- }
-
- public SamzaEngine setKafka(String kafka) {
- this.kafka = kafka;
- return this;
- }
-
- public SamzaEngine setKafkaReplicationFactor(int replicationFactor) {
- this.kafkaReplicationFactor = replicationFactor;
- return this;
- }
-
- public SamzaEngine setCheckpointFrequency(int freq) {
- this.checkpointFrequency = freq;
- return this;
- }
-
- public SamzaEngine setLocalMode(boolean isLocal) {
- this.isLocalMode = isLocal;
- return this;
- }
-
- public SamzaEngine setYarnPackage(String yarnPackagePath) {
- this.yarnPackagePath = yarnPackagePath;
- return this;
- }
-
- public SamzaEngine setConfigHome(String configHome) {
- this.yarnConfHome = configHome;
- return this;
- }
-
- public SamzaEngine setAMMemory(int mem) {
- this.amMem = mem;
- return this;
- }
-
- public SamzaEngine setContainerMemory(int mem) {
- this.containerMem = mem;
- return this;
- }
-
- public SamzaEngine setPiPerContainerRatio(int piPerContainer) {
- this.piPerContainerRatio = piPerContainer;
- return this;
- }
-
- public SamzaEngine setKryoRegisterFile(String registerFile) {
- this.kryoRegisterFile = registerFile;
- return this;
- }
-
- /**
- * Submit a list of Samza jobs correspond to the submitted
- * topology
- *
- * @param topo
- * the submitted topology
- */
- public static void submitTopology(SamzaTopology topo) {
- // Setup SystemsUtils
- engine._setupSystemsUtils();
-
- // Submit topology
- engine._submitTopology(topo);
- }
+ private static final Logger logger = LoggerFactory.getLogger(SamzaEngine.class);
+
+ /*
+ * Singleton instance
+ */
+ private static SamzaEngine engine = new SamzaEngine();
+
+ private String zookeeper;
+ private String kafka;
+ private int kafkaReplicationFactor;
+ private boolean isLocalMode;
+ private String yarnPackagePath;
+ private String yarnConfHome;
+
+ private String kryoRegisterFile;
+
+ private int amMem;
+ private int containerMem;
+ private int piPerContainerRatio;
+
+ private int checkpointFrequency;
+
+ private void _submitTopology(SamzaTopology topology) {
+
+ // Setup SamzaConfigFactory
+ SamzaConfigFactory configFactory = new SamzaConfigFactory();
+ configFactory.setLocalMode(isLocalMode)
+ .setZookeeper(zookeeper)
+ .setKafka(kafka)
+ .setYarnPackage(yarnPackagePath)
+ .setAMMemory(amMem)
+ .setContainerMemory(containerMem)
+ .setPiPerContainerRatio(piPerContainerRatio)
+ .setKryoRegisterFile(kryoRegisterFile)
+ .setCheckpointFrequency(checkpointFrequency)
+ .setReplicationFactor(kafkaReplicationFactor);
+
+ // Generate the list of Configs
+ List<MapConfig> configs;
+ try {
+ // ConfigFactory generate a list of configs
+ // Serialize a map of PIs and store in a file in the jar at jarFilePath
+ // (in dat/ folder)
+ configs = configFactory.getMapConfigsForTopology(topology);
+ } catch (Exception e) {
+ e.printStackTrace();
+ return;
+ }
+
+ // Create kafka streams
+ Set<Stream> streams = topology.getStreams();
+ for (Stream stream : streams) {
+ SamzaStream samzaStream = (SamzaStream) stream;
+ List<SamzaSystemStream> systemStreams = samzaStream.getSystemStreams();
+ for (SamzaSystemStream systemStream : systemStreams) {
+ // all streams should be kafka streams
+ SystemsUtils.createKafkaTopic(systemStream.getStream(), systemStream.getParallelism(), kafkaReplicationFactor);
+ }
+ }
+
+ // Submit the jobs with those configs
+ for (MapConfig config : configs) {
+ logger.info("Config:{}", config);
+ JobRunner jobRunner = new JobRunner(config);
+ jobRunner.run();
+ }
+ }
+
+ private void _setupSystemsUtils() {
+ // Setup Utils
+ if (!isLocalMode)
+ SystemsUtils.setHadoopConfigHome(yarnConfHome);
+ SystemsUtils.setZookeeper(zookeeper);
+ }
+
+ /*
+ * Setter methods
+ */
+ public static SamzaEngine getEngine() {
+ return engine;
+ }
+
+ public SamzaEngine setZooKeeper(String zk) {
+ this.zookeeper = zk;
+ return this;
+ }
+
+ public SamzaEngine setKafka(String kafka) {
+ this.kafka = kafka;
+ return this;
+ }
+
+ public SamzaEngine setKafkaReplicationFactor(int replicationFactor) {
+ this.kafkaReplicationFactor = replicationFactor;
+ return this;
+ }
+
+ public SamzaEngine setCheckpointFrequency(int freq) {
+ this.checkpointFrequency = freq;
+ return this;
+ }
+
+ public SamzaEngine setLocalMode(boolean isLocal) {
+ this.isLocalMode = isLocal;
+ return this;
+ }
+
+ public SamzaEngine setYarnPackage(String yarnPackagePath) {
+ this.yarnPackagePath = yarnPackagePath;
+ return this;
+ }
+
+ public SamzaEngine setConfigHome(String configHome) {
+ this.yarnConfHome = configHome;
+ return this;
+ }
+
+ public SamzaEngine setAMMemory(int mem) {
+ this.amMem = mem;
+ return this;
+ }
+
+ public SamzaEngine setContainerMemory(int mem) {
+ this.containerMem = mem;
+ return this;
+ }
+
+ public SamzaEngine setPiPerContainerRatio(int piPerContainer) {
+ this.piPerContainerRatio = piPerContainer;
+ return this;
+ }
+
+ public SamzaEngine setKryoRegisterFile(String registerFile) {
+ this.kryoRegisterFile = registerFile;
+ return this;
+ }
+
+ /**
+ * Submit a list of Samza jobs corresponding to the submitted topology
+ *
+ * @param topo
+ * the submitted topology
+ */
+ public static void submitTopology(SamzaTopology topo) {
+ // Setup SystemsUtils
+ engine._setupSystemsUtils();
+
+ // Submit topology
+ engine._submitTopology(topo);
+ }
}
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEntranceProcessingItem.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEntranceProcessingItem.java
index e89d789..6eea7cb 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEntranceProcessingItem.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaEntranceProcessingItem.java
@@ -44,179 +44,196 @@
import com.yahoo.labs.samoa.utils.SystemsUtils;
/**
- * EntranceProcessingItem for Samza
- * which is also a Samza task (StreamTask & InitableTask)
+ * EntranceProcessingItem for Samza which is also a Samza task (StreamTask &
+ * InitableTask)
*
* @author Anh Thu Vu
- *
+ *
*/
public class SamzaEntranceProcessingItem extends AbstractEntranceProcessingItem
- implements SamzaProcessingNode, Serializable, StreamTask, InitableTask {
+ implements SamzaProcessingNode, Serializable, StreamTask, InitableTask {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 7157734520046135039L;
-
- /*
- * Constructors
- */
- public SamzaEntranceProcessingItem(EntranceProcessor processor) {
- super(processor);
- }
-
- // Need this so Samza can initialize a StreamTask
- public SamzaEntranceProcessingItem() {}
-
- /*
- * Simple setters, getters
- */
- @Override
- public int addOutputStream(SamzaStream stream) {
- this.setOutputStream(stream);
- return 1; // entrance PI should have only 1 output stream
- }
-
- /*
- * Serialization
- */
- private Object writeReplace() {
- return new SerializationProxy(this);
- }
-
- private static class SerializationProxy implements Serializable {
- /**
+ private static final long serialVersionUID = 7157734520046135039L;
+
+ /*
+ * Constructors
+ */
+ public SamzaEntranceProcessingItem(EntranceProcessor processor) {
+ super(processor);
+ }
+
+ // Need this so Samza can initialize a StreamTask
+ public SamzaEntranceProcessingItem() {
+ }
+
+ /*
+ * Simple setters, getters
+ */
+ @Override
+ public int addOutputStream(SamzaStream stream) {
+ this.setOutputStream(stream);
+ return 1; // entrance PI should have only 1 output stream
+ }
+
+ /*
+ * Serialization
+ */
+ private Object writeReplace() {
+ return new SerializationProxy(this);
+ }
+
+ private static class SerializationProxy implements Serializable {
+ /**
*
*/
- private static final long serialVersionUID = 313907132721414634L;
-
- private EntranceProcessor processor;
- private SamzaStream outputStream;
- private String name;
-
- public SerializationProxy(SamzaEntranceProcessingItem epi) {
- this.processor = epi.getProcessor();
- this.outputStream = (SamzaStream)epi.getOutputStream();
- this.name = epi.getName();
- }
- }
-
- /*
- * Implement Samza Task
- */
- @Override
- public void init(Config config, TaskContext context) throws Exception {
- String yarnConfHome = config.get(SamzaConfigFactory.YARN_CONF_HOME_KEY);
- if (yarnConfHome != null && yarnConfHome.length() > 0) // if the property is set , otherwise, assume we are running in
- // local mode and ignore this
- SystemsUtils.setHadoopConfigHome(yarnConfHome);
-
- String filename = config.get(SamzaConfigFactory.FILE_KEY);
- String filesystem = config.get(SamzaConfigFactory.FILESYSTEM_KEY);
-
- this.setName(config.get(SamzaConfigFactory.JOB_NAME_KEY));
- SerializationProxy wrapper = (SerializationProxy) SystemsUtils.deserializeObjectFromFileAndKey(filesystem, filename, this.getName());
- this.setOutputStream(wrapper.outputStream);
- SamzaStream output = (SamzaStream)this.getOutputStream();
- if (output != null) // if output stream exists, set it up
- output.onCreate();
- }
+ private static final long serialVersionUID = 313907132721414634L;
- @Override
- public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) throws Exception {
- SamzaStream output = (SamzaStream)this.getOutputStream();
- if (output == null) return; // if there is no output stream, do nothing
- output.setCollector(collector);
- ContentEvent event = (ContentEvent) envelope.getMessage();
- output.put(event);
- }
-
- /*
- * Implementation of Samza's SystemConsumer to get events from source
- * and feed to SAMOA system
- *
- */
- /* Current implementation: buffer the incoming events and send a batch
- * of them when poll() is called by Samza system.
- *
- * Currently: it has a "soft" limit on the size of the buffer:
- * when the buffer size reaches the limit, the reading thread will sleep
- * for 100ms.
- * A hard limit can be achieved by overriding the method
- * protected BlockingQueue<IncomingMessageEnvelope> newBlockingQueue()
- * of BlockingEnvelopeMap
- * But then we have handle the case when the queue is full.
- *
- */
- public static class SamoaSystemConsumer extends BlockingEnvelopeMap {
-
- private EntranceProcessor entranceProcessor = null;
- private SystemStreamPartition systemStreamPartition;
-
- private static final Logger logger = LoggerFactory.getLogger(SamoaSystemConsumer.class);
+ private EntranceProcessor processor;
+ private SamzaStream outputStream;
+ private String name;
- public SamoaSystemConsumer(String systemName, Config config) {
- String yarnConfHome = config.get(SamzaConfigFactory.YARN_CONF_HOME_KEY);
- if (yarnConfHome != null && yarnConfHome.length() > 0) // if the property is set , otherwise, assume we are running in
- // local mode and ignore this
- SystemsUtils.setHadoopConfigHome(yarnConfHome);
-
- String filename = config.get(SamzaConfigFactory.FILE_KEY);
- String filesystem = config.get(SamzaConfigFactory.FILESYSTEM_KEY);
- String name = config.get(SamzaConfigFactory.JOB_NAME_KEY);
- SerializationProxy wrapper = (SerializationProxy) SystemsUtils.deserializeObjectFromFileAndKey(filesystem, filename, name);
-
- this.entranceProcessor = wrapper.processor;
- this.entranceProcessor.onCreate(0);
-
- // Internal stream from SystemConsumer to EntranceTask, so we
- // need only one partition
- this.systemStreamPartition = new SystemStreamPartition(systemName, wrapper.name, new Partition(0));
- }
-
- @Override
- public void start() {
- Thread processorPollingThread = new Thread(
- new Runnable() {
- @Override
- public void run() {
- try {
- pollingEntranceProcessor();
- setIsAtHead(systemStreamPartition, true);
- } catch (InterruptedException e) {
- e.getStackTrace();
- stop();
- }
- }
- }
- );
+ public SerializationProxy(SamzaEntranceProcessingItem epi) {
+ this.processor = epi.getProcessor();
+ this.outputStream = (SamzaStream) epi.getOutputStream();
+ this.name = epi.getName();
+ }
+ }
- processorPollingThread.start();
- }
+ /*
+ * Implement Samza Task
+ */
+ @Override
+ public void init(Config config, TaskContext context) throws Exception {
+ String yarnConfHome = config.get(SamzaConfigFactory.YARN_CONF_HOME_KEY);
+ if (yarnConfHome != null && yarnConfHome.length() > 0) // if the property is
+ // set , otherwise,
+ // assume we are
+ // running in
+ // local mode and ignore this
+ SystemsUtils.setHadoopConfigHome(yarnConfHome);
- @Override
- public void stop() {
+ String filename = config.get(SamzaConfigFactory.FILE_KEY);
+ String filesystem = config.get(SamzaConfigFactory.FILESYSTEM_KEY);
- }
-
- private void pollingEntranceProcessor() throws InterruptedException {
- int messageCnt = 0;
- while(!this.entranceProcessor.isFinished()) {
- messageCnt = this.getNumMessagesInQueue(systemStreamPartition);
- if (this.entranceProcessor.hasNext() && messageCnt < 10000) { // soft limit on the size of the queue
- this.put(systemStreamPartition, new IncomingMessageEnvelope(systemStreamPartition,null, null,this.entranceProcessor.nextEvent()));
- } else {
- try {
- Thread.sleep(100);
- } catch (InterruptedException e) {
- break;
- }
- }
- }
-
- // Send last event
- this.put(systemStreamPartition, new IncomingMessageEnvelope(systemStreamPartition,null, null,this.entranceProcessor.nextEvent()));
- }
-
- }
+ this.setName(config.get(SamzaConfigFactory.JOB_NAME_KEY));
+ SerializationProxy wrapper = (SerializationProxy) SystemsUtils.deserializeObjectFromFileAndKey(filesystem,
+ filename, this.getName());
+ this.setOutputStream(wrapper.outputStream);
+ SamzaStream output = (SamzaStream) this.getOutputStream();
+ if (output != null) // if output stream exists, set it up
+ output.onCreate();
+ }
+
+ @Override
+ public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator)
+ throws Exception {
+ SamzaStream output = (SamzaStream) this.getOutputStream();
+ if (output == null)
+ return; // if there is no output stream, do nothing
+ output.setCollector(collector);
+ ContentEvent event = (ContentEvent) envelope.getMessage();
+ output.put(event);
+ }
+
+ /*
+ * Implementation of Samza's SystemConsumer to get events from source and feed
+ * to SAMOA system
+ */
+ /*
+ * Current implementation: buffer the incoming events and send a batch of them
+ * when poll() is called by Samza system.
+ *
+ * Currently: it has a "soft" limit on the size of the buffer: when the buffer
+ * size reaches the limit, the reading thread will sleep for 100ms. A hard
+ * limit can be achieved by overriding the method protected
+ * BlockingQueue<IncomingMessageEnvelope> newBlockingQueue() of
+ * BlockingEnvelopeMap But then we have handle the case when the queue is
+ * full.
+ */
+ public static class SamoaSystemConsumer extends BlockingEnvelopeMap {
+
+ private EntranceProcessor entranceProcessor = null;
+ private SystemStreamPartition systemStreamPartition;
+
+ private static final Logger logger = LoggerFactory.getLogger(SamoaSystemConsumer.class);
+
+ public SamoaSystemConsumer(String systemName, Config config) {
+ String yarnConfHome = config.get(SamzaConfigFactory.YARN_CONF_HOME_KEY);
+ if (yarnConfHome != null && yarnConfHome.length() > 0) // if the property
+ // is set ,
+ // otherwise,
+ // assume we are
+ // running in
+ // local mode and ignore this
+ SystemsUtils.setHadoopConfigHome(yarnConfHome);
+
+ String filename = config.get(SamzaConfigFactory.FILE_KEY);
+ String filesystem = config.get(SamzaConfigFactory.FILESYSTEM_KEY);
+ String name = config.get(SamzaConfigFactory.JOB_NAME_KEY);
+ SerializationProxy wrapper = (SerializationProxy) SystemsUtils.deserializeObjectFromFileAndKey(filesystem,
+ filename, name);
+
+ this.entranceProcessor = wrapper.processor;
+ this.entranceProcessor.onCreate(0);
+
+ // Internal stream from SystemConsumer to EntranceTask, so we
+ // need only one partition
+ this.systemStreamPartition = new SystemStreamPartition(systemName, wrapper.name, new Partition(0));
+ }
+
+ @Override
+ public void start() {
+ Thread processorPollingThread = new Thread(
+ new Runnable() {
+ @Override
+ public void run() {
+ try {
+ pollingEntranceProcessor();
+ setIsAtHead(systemStreamPartition, true);
+ } catch (InterruptedException e) {
+ e.getStackTrace();
+ stop();
+ }
+ }
+ }
+ );
+
+ processorPollingThread.start();
+ }
+
+ @Override
+ public void stop() {
+
+ }
+
+ private void pollingEntranceProcessor() throws InterruptedException {
+ int messageCnt = 0;
+ while (!this.entranceProcessor.isFinished()) {
+ messageCnt = this.getNumMessagesInQueue(systemStreamPartition);
+ if (this.entranceProcessor.hasNext() && messageCnt < 10000) { // soft
+ // limit
+ // on the
+ // size of
+ // the
+ // queue
+ this.put(systemStreamPartition, new IncomingMessageEnvelope(systemStreamPartition, null, null,
+ this.entranceProcessor.nextEvent()));
+ } else {
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ break;
+ }
+ }
+ }
+
+ // Send last event
+ this.put(systemStreamPartition, new IncomingMessageEnvelope(systemStreamPartition, null, null,
+ this.entranceProcessor.nextEvent()));
+ }
+
+ }
}
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingItem.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingItem.java
index db72e7c..7c97e65 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingItem.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingItem.java
@@ -46,120 +46,127 @@
import org.apache.samza.task.TaskCoordinator;
/**
- * ProcessingItem for Samza
- * which is also a Samza task (StreamTask and InitableTask)
+ * ProcessingItem for Samza which is also a Samza task (StreamTask and
+ * InitableTask)
*
* @author Anh Thu Vu
*/
-public class SamzaProcessingItem extends AbstractProcessingItem
- implements SamzaProcessingNode, Serializable, StreamTask, InitableTask {
-
- /**
+public class SamzaProcessingItem extends AbstractProcessingItem
+ implements SamzaProcessingNode, Serializable, StreamTask, InitableTask {
+
+ /**
*
*/
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- private Set<SamzaSystemStream> inputStreams; // input streams: system.stream
- private List<SamzaStream> outputStreams;
-
- /*
- * Constructors
- */
- // Need this so Samza can initialize a StreamTask
- public SamzaProcessingItem() {}
-
- /*
- * Implement com.yahoo.labs.samoa.topology.ProcessingItem
- */
- public SamzaProcessingItem(Processor processor, int parallelismHint) {
- super(processor, parallelismHint);
- this.inputStreams = new HashSet<SamzaSystemStream>();
- this.outputStreams = new LinkedList<SamzaStream>();
- }
-
- /*
- * Simple setters, getters
- */
- public Set<SamzaSystemStream> getInputStreams() {
- return this.inputStreams;
- }
-
- /*
- * Extends AbstractProcessingItem
- */
- @Override
- protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
- SamzaSystemStream stream = ((SamzaStream) inputStream).addDestination(new StreamDestination(this,this.getParallelism(),scheme));
- this.inputStreams.add(stream);
- return this;
- }
+ private Set<SamzaSystemStream> inputStreams; // input streams: system.stream
+ private List<SamzaStream> outputStreams;
- /*
- * Implement com.yahoo.samoa.topology.impl.SamzaProcessingNode
- */
- @Override
- public int addOutputStream(SamzaStream stream) {
- this.outputStreams.add(stream);
- return this.outputStreams.size();
- }
-
- public List<SamzaStream> getOutputStreams() {
- return this.outputStreams;
- }
+ /*
+ * Constructors
+ */
+ // Need this so Samza can initialize a StreamTask
+ public SamzaProcessingItem() {
+ }
- /*
- * Implement Samza task
- */
- @Override
- public void init(Config config, TaskContext context) throws Exception {
- String yarnConfHome = config.get(SamzaConfigFactory.YARN_CONF_HOME_KEY);
- if (yarnConfHome != null && yarnConfHome.length() > 0) // if the property is set , otherwise, assume we are running in
- // local mode and ignore this
- SystemsUtils.setHadoopConfigHome(yarnConfHome);
-
- String filename = config.get(SamzaConfigFactory.FILE_KEY);
- String filesystem = config.get(SamzaConfigFactory.FILESYSTEM_KEY);
- this.setName(config.get(SamzaConfigFactory.JOB_NAME_KEY));
- SerializationProxy wrapper = (SerializationProxy) SystemsUtils.deserializeObjectFromFileAndKey(filesystem, filename, this.getName());
- this.setProcessor(wrapper.processor);
- this.outputStreams = wrapper.outputStreams;
-
- // Init Processor and Streams
- this.getProcessor().onCreate(0);
- for (SamzaStream stream:this.outputStreams) {
- stream.onCreate();
- }
-
- }
+ /*
+ * Implement com.yahoo.labs.samoa.topology.ProcessingItem
+ */
+ public SamzaProcessingItem(Processor processor, int parallelismHint) {
+ super(processor, parallelismHint);
+ this.inputStreams = new HashSet<SamzaSystemStream>();
+ this.outputStreams = new LinkedList<SamzaStream>();
+ }
- @Override
- public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) throws Exception {
- for (SamzaStream stream:this.outputStreams) {
- stream.setCollector(collector);
- }
- this.getProcessor().process((ContentEvent) envelope.getMessage());
- }
-
- /*
- * SerializationProxy
- */
- private Object writeReplace() {
- return new SerializationProxy(this);
- }
-
- private static class SerializationProxy implements Serializable {
- /**
+ /*
+ * Simple setters, getters
+ */
+ public Set<SamzaSystemStream> getInputStreams() {
+ return this.inputStreams;
+ }
+
+ /*
+ * Extends AbstractProcessingItem
+ */
+ @Override
+ protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
+ SamzaSystemStream stream = ((SamzaStream) inputStream).addDestination(new StreamDestination(this, this
+ .getParallelism(), scheme));
+ this.inputStreams.add(stream);
+ return this;
+ }
+
+ /*
+ * Implement com.yahoo.samoa.topology.impl.SamzaProcessingNode
+ */
+ @Override
+ public int addOutputStream(SamzaStream stream) {
+ this.outputStreams.add(stream);
+ return this.outputStreams.size();
+ }
+
+ public List<SamzaStream> getOutputStreams() {
+ return this.outputStreams;
+ }
+
+ /*
+ * Implement Samza task
+ */
+ @Override
+ public void init(Config config, TaskContext context) throws Exception {
+ String yarnConfHome = config.get(SamzaConfigFactory.YARN_CONF_HOME_KEY);
+ if (yarnConfHome != null && yarnConfHome.length() > 0) // if the property is
+ // set , otherwise,
+ // assume we are
+ // running in
+ // local mode and ignore this
+ SystemsUtils.setHadoopConfigHome(yarnConfHome);
+
+ String filename = config.get(SamzaConfigFactory.FILE_KEY);
+ String filesystem = config.get(SamzaConfigFactory.FILESYSTEM_KEY);
+ this.setName(config.get(SamzaConfigFactory.JOB_NAME_KEY));
+ SerializationProxy wrapper = (SerializationProxy) SystemsUtils.deserializeObjectFromFileAndKey(filesystem,
+ filename, this.getName());
+ this.setProcessor(wrapper.processor);
+ this.outputStreams = wrapper.outputStreams;
+
+ // Init Processor and Streams
+ this.getProcessor().onCreate(0);
+ for (SamzaStream stream : this.outputStreams) {
+ stream.onCreate();
+ }
+
+ }
+
+ @Override
+ public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator)
+ throws Exception {
+ for (SamzaStream stream : this.outputStreams) {
+ stream.setCollector(collector);
+ }
+ this.getProcessor().process((ContentEvent) envelope.getMessage());
+ }
+
+ /*
+ * SerializationProxy
+ */
+ private Object writeReplace() {
+ return new SerializationProxy(this);
+ }
+
+ private static class SerializationProxy implements Serializable {
+ /**
*
*/
- private static final long serialVersionUID = 1534643987559070336L;
-
- private Processor processor;
- private List<SamzaStream> outputStreams;
-
- public SerializationProxy(SamzaProcessingItem pi) {
- this.processor = pi.getProcessor();
- this.outputStreams = pi.getOutputStreams();
- }
- }
+ private static final long serialVersionUID = 1534643987559070336L;
+
+ private Processor processor;
+ private List<SamzaStream> outputStreams;
+
+ public SerializationProxy(SamzaProcessingItem pi) {
+ this.processor = pi.getProcessor();
+ this.outputStreams = pi.getOutputStreams();
+ }
+ }
}
\ No newline at end of file
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingNode.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingNode.java
index be13673..1dbccb6 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingNode.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaProcessingNode.java
@@ -23,34 +23,36 @@
import com.yahoo.labs.samoa.topology.IProcessingItem;
/**
- * Common interface of SamzaEntranceProcessingItem and
- * SamzaProcessingItem
+ * Common interface of SamzaEntranceProcessingItem and SamzaProcessingItem
*
* @author Anh Thu Vu
*/
public interface SamzaProcessingNode extends IProcessingItem {
- /**
- * Registers an output stream with this processing item
- *
- * @param stream
- * the output stream
- * @return the number of output streams of this processing item
- */
- public int addOutputStream(SamzaStream stream);
-
- /**
- * Gets the name/id of this processing item
- *
- * @return the name/id of this processing item
- */
- // TODO: include getName() and setName() in IProcessingItem and/or AbstractEPI/PI
- public String getName();
-
- /**
- * Sets the name/id for this processing item
- * @param name
- * the name/id of this processing item
- */
- // TODO: include getName() and setName() in IProcessingItem and/or AbstractEPI/PI
- public void setName(String name);
+ /**
+ * Registers an output stream with this processing item
+ *
+ * @param stream
+ * the output stream
+ * @return the number of output streams of this processing item
+ */
+ public int addOutputStream(SamzaStream stream);
+
+ /**
+ * Gets the name/id of this processing item
+ *
+ * @return the name/id of this processing item
+ */
+ // TODO: include getName() and setName() in IProcessingItem and/or
+ // AbstractEPI/PI
+ public String getName();
+
+ /**
+ * Sets the name/id for this processing item
+ *
+ * @param name
+ * the name/id of this processing item
+ */
+ // TODO: include getName() and setName() in IProcessingItem and/or
+ // AbstractEPI/PI
+ public void setName(String name);
}
\ No newline at end of file
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaStream.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaStream.java
index c1bf5a2..a855e46 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaStream.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaStream.java
@@ -39,210 +39,209 @@
*
* @author Anh Thu Vu
*/
-public class SamzaStream extends AbstractStream implements Serializable {
+public class SamzaStream extends AbstractStream implements Serializable {
- /**
+ /**
*
*/
- private static final long serialVersionUID = 1L;
+ private static final long serialVersionUID = 1L;
- private static final String DEFAULT_SYSTEM_NAME = "kafka";
-
- private List<SamzaSystemStream> systemStreams;
- private transient MessageCollector collector;
- private String systemName;
-
- /*
- * Constructor
- */
- public SamzaStream(IProcessingItem sourcePi) {
- super(sourcePi);
- this.systemName = DEFAULT_SYSTEM_NAME;
- // Get name/id for this stream
- SamzaProcessingNode samzaPi = (SamzaProcessingNode) sourcePi;
- int index = samzaPi.addOutputStream(this);
- this.setStreamId(samzaPi.getName()+"-"+Integer.toString(index));
- // init list of SamzaSystemStream
- systemStreams = new ArrayList<SamzaSystemStream>();
- }
-
- /*
- * System name (Kafka)
- */
- public void setSystemName(String systemName) {
- this.systemName = systemName;
- for (SamzaSystemStream systemStream:systemStreams) {
- systemStream.setSystem(systemName);
- }
- }
+ private static final String DEFAULT_SYSTEM_NAME = "kafka";
- public String getSystemName() {
- return this.systemName;
- }
+ private List<SamzaSystemStream> systemStreams;
+ private transient MessageCollector collector;
+ private String systemName;
- /*
- * Add the PI to the list of destinations.
- * Return the name of the corresponding SystemStream.
- */
- public SamzaSystemStream addDestination(StreamDestination destination) {
- PartitioningScheme scheme = destination.getPartitioningScheme();
- int parallelism = destination.getParallelism();
-
- SamzaSystemStream resultStream = null;
- for (int i=0; i<systemStreams.size(); i++) {
- // There is an existing SystemStream that matches the settings.
- // Do not create a new one
- if (systemStreams.get(i).isSame(scheme, parallelism)) {
- resultStream = systemStreams.get(i);
- }
- }
-
- // No existing SystemStream match the requirement
- // Create a new one
- if (resultStream == null) {
- String topicName = this.getStreamId() + "-" + Integer.toString(systemStreams.size());
- resultStream = new SamzaSystemStream(this.systemName,topicName,scheme,parallelism);
- systemStreams.add(resultStream);
- }
-
- return resultStream;
- }
-
- public void setCollector(MessageCollector collector) {
- this.collector = collector;
- }
-
- public MessageCollector getCollector(){
- return this.collector;
- }
-
- public void onCreate() {
- for (SamzaSystemStream stream:systemStreams) {
- stream.initSystemStream();
- }
- }
-
- /*
- * Implement Stream interface
- */
- @Override
- public void put(ContentEvent event) {
- for (SamzaSystemStream stream:systemStreams) {
- stream.send(collector,event);
- }
- }
-
- public List<SamzaSystemStream> getSystemStreams() {
- return this.systemStreams;
- }
-
- /**
- * SamzaSystemStream wrap around a Samza's SystemStream
- * It contains the info to create a Samza stream during the
- * constructing process of the topology and
- * will create the actual Samza stream when the topology is submitted
- * (invoking initSystemStream())
- *
- * @author Anh Thu Vu
- */
- public static class SamzaSystemStream implements Serializable {
- /**
+ /*
+ * Constructor
+ */
+ public SamzaStream(IProcessingItem sourcePi) {
+ super(sourcePi);
+ this.systemName = DEFAULT_SYSTEM_NAME;
+ // Get name/id for this stream
+ SamzaProcessingNode samzaPi = (SamzaProcessingNode) sourcePi;
+ int index = samzaPi.addOutputStream(this);
+ this.setStreamId(samzaPi.getName() + "-" + Integer.toString(index));
+ // init list of SamzaSystemStream
+ systemStreams = new ArrayList<SamzaSystemStream>();
+ }
+
+ /*
+ * System name (Kafka)
+ */
+ public void setSystemName(String systemName) {
+ this.systemName = systemName;
+ for (SamzaSystemStream systemStream : systemStreams) {
+ systemStream.setSystem(systemName);
+ }
+ }
+
+ public String getSystemName() {
+ return this.systemName;
+ }
+
+ /*
+ * Add the PI to the list of destinations. Return the
+ * corresponding SamzaSystemStream.
+ */
+ public SamzaSystemStream addDestination(StreamDestination destination) {
+ PartitioningScheme scheme = destination.getPartitioningScheme();
+ int parallelism = destination.getParallelism();
+
+ SamzaSystemStream resultStream = null;
+ for (int i = 0; i < systemStreams.size(); i++) {
+ // There is an existing SystemStream that matches the settings.
+ // Do not create a new one
+ if (systemStreams.get(i).isSame(scheme, parallelism)) {
+ resultStream = systemStreams.get(i);
+ }
+ }
+
+ // No existing SystemStream matches the requirement
+ // Create a new one
+ if (resultStream == null) {
+ String topicName = this.getStreamId() + "-" + Integer.toString(systemStreams.size());
+ resultStream = new SamzaSystemStream(this.systemName, topicName, scheme, parallelism);
+ systemStreams.add(resultStream);
+ }
+
+ return resultStream;
+ }
+
+ public void setCollector(MessageCollector collector) {
+ this.collector = collector;
+ }
+
+ public MessageCollector getCollector() {
+ return this.collector;
+ }
+
+ public void onCreate() {
+ for (SamzaSystemStream stream : systemStreams) {
+ stream.initSystemStream();
+ }
+ }
+
+ /*
+ * Implement Stream interface
+ */
+ @Override
+ public void put(ContentEvent event) {
+ for (SamzaSystemStream stream : systemStreams) {
+ stream.send(collector, event);
+ }
+ }
+
+ public List<SamzaSystemStream> getSystemStreams() {
+ return this.systemStreams;
+ }
+
+ /**
+ * SamzaSystemStream wraps around a Samza SystemStream. It contains the info
+ * to create a Samza stream during the construction of the topology
+ * and will create the actual Samza stream when the topology is submitted
+ * (invoking initSystemStream())
+ *
+ * @author Anh Thu Vu
+ */
+ public static class SamzaSystemStream implements Serializable {
+ /**
*
*/
- private static final long serialVersionUID = 1L;
- private String system;
- private String stream;
- private PartitioningScheme scheme;
- private int parallelism;
-
- private transient SystemStream actualSystemStream = null;
-
- /*
- * Constructors
- */
- public SamzaSystemStream(String system, String stream, PartitioningScheme scheme, int parallelism) {
- this.system = system;
- this.stream = stream;
- this.scheme = scheme;
- this.parallelism = parallelism;
- }
-
- public SamzaSystemStream(String system, String stream, PartitioningScheme scheme) {
- this(system, stream, scheme, 1);
- }
-
- /*
- * Setters
- */
- public void setSystem(String system) {
- this.system = system;
- }
-
- /*
- * Getters
- */
- public String getSystem() {
- return this.system;
- }
-
- public String getStream() {
- return this.stream;
- }
-
- public PartitioningScheme getPartitioningScheme() {
- return this.scheme;
- }
-
- public int getParallelism() {
- return this.parallelism;
- }
+ private static final long serialVersionUID = 1L;
+ private String system;
+ private String stream;
+ private PartitioningScheme scheme;
+ private int parallelism;
- public boolean isSame(PartitioningScheme scheme, int parallelismHint) {
- return (this.scheme == scheme && this.parallelism == parallelismHint);
- }
-
- /*
- * Init the actual Samza stream
- */
- public void initSystemStream() {
- actualSystemStream = new SystemStream(this.system, this.stream);
- }
-
- /*
- * Send a ContentEvent
- */
- public void send(MessageCollector collector, ContentEvent contentEvent) {
- if (actualSystemStream == null)
- this.initSystemStream();
-
- switch(this.scheme) {
- case SHUFFLE:
- this.sendShuffle(collector, contentEvent);
- break;
- case GROUP_BY_KEY:
- this.sendGroupByKey(collector, contentEvent);
- break;
- case BROADCAST:
- this.sendBroadcast(collector, contentEvent);
- break;
- }
- }
-
- /*
- * Helpers
- */
- private synchronized void sendShuffle(MessageCollector collector, ContentEvent event) {
- collector.send(new OutgoingMessageEnvelope(this.actualSystemStream, event));
- }
-
- private void sendGroupByKey(MessageCollector collector, ContentEvent event) {
- collector.send(new OutgoingMessageEnvelope(this.actualSystemStream, event.getKey(), null, event));
- }
+ private transient SystemStream actualSystemStream = null;
- private synchronized void sendBroadcast(MessageCollector collector, ContentEvent event) {
- for (int i=0; i<parallelism; i++) {
- collector.send(new OutgoingMessageEnvelope(this.actualSystemStream, i, null, event));
- }
- }
- }
+ /*
+ * Constructors
+ */
+ public SamzaSystemStream(String system, String stream, PartitioningScheme scheme, int parallelism) {
+ this.system = system;
+ this.stream = stream;
+ this.scheme = scheme;
+ this.parallelism = parallelism;
+ }
+
+ public SamzaSystemStream(String system, String stream, PartitioningScheme scheme) {
+ this(system, stream, scheme, 1);
+ }
+
+ /*
+ * Setters
+ */
+ public void setSystem(String system) {
+ this.system = system;
+ }
+
+ /*
+ * Getters
+ */
+ public String getSystem() {
+ return this.system;
+ }
+
+ public String getStream() {
+ return this.stream;
+ }
+
+ public PartitioningScheme getPartitioningScheme() {
+ return this.scheme;
+ }
+
+ public int getParallelism() {
+ return this.parallelism;
+ }
+
+ public boolean isSame(PartitioningScheme scheme, int parallelismHint) {
+ return (this.scheme == scheme && this.parallelism == parallelismHint);
+ }
+
+ /*
+ * Init the actual Samza stream
+ */
+ public void initSystemStream() {
+ actualSystemStream = new SystemStream(this.system, this.stream);
+ }
+
+ /*
+ * Send a ContentEvent
+ */
+ public void send(MessageCollector collector, ContentEvent contentEvent) {
+ if (actualSystemStream == null)
+ this.initSystemStream();
+
+ switch (this.scheme) {
+ case SHUFFLE:
+ this.sendShuffle(collector, contentEvent);
+ break;
+ case GROUP_BY_KEY:
+ this.sendGroupByKey(collector, contentEvent);
+ break;
+ case BROADCAST:
+ this.sendBroadcast(collector, contentEvent);
+ break;
+ }
+ }
+
+ /*
+ * Helpers
+ */
+ private synchronized void sendShuffle(MessageCollector collector, ContentEvent event) {
+ collector.send(new OutgoingMessageEnvelope(this.actualSystemStream, event));
+ }
+
+ private void sendGroupByKey(MessageCollector collector, ContentEvent event) {
+ collector.send(new OutgoingMessageEnvelope(this.actualSystemStream, event.getKey(), null, event));
+ }
+
+ private synchronized void sendBroadcast(MessageCollector collector, ContentEvent event) {
+ for (int i = 0; i < parallelism; i++) {
+ collector.send(new OutgoingMessageEnvelope(this.actualSystemStream, i, null, event));
+ }
+ }
+ }
}
\ No newline at end of file
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaTopology.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaTopology.java
index a169bc2..4e52966 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaTopology.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/topology/impl/SamzaTopology.java
@@ -32,33 +32,33 @@
* @author Anh Thu Vu
*/
public class SamzaTopology extends AbstractTopology {
- private int procItemCounter;
-
- public SamzaTopology(String topoName) {
- super(topoName);
- procItemCounter = 0;
- }
-
- @Override
- public void addProcessingItem(IProcessingItem procItem, int parallelism) {
- super.addProcessingItem(procItem, parallelism);
- SamzaProcessingNode samzaPi = (SamzaProcessingNode) procItem;
- samzaPi.setName(this.getTopologyName()+"-"+Integer.toString(procItemCounter));
- procItemCounter++;
- }
-
- /*
- * Gets the set of ProcessingItems, excluding EntrancePIs
- * Used by SamzaConfigFactory as the config for EntrancePIs and
- * normal PIs are different
- */
- public Set<IProcessingItem> getNonEntranceProcessingItems() throws Exception {
- Set<IProcessingItem> copiedSet = new HashSet<IProcessingItem>();
- copiedSet.addAll(this.getProcessingItems());
- boolean result = copiedSet.removeAll(this.getEntranceProcessingItems());
- if (!result) {
- throw new Exception("Failed extracting the set of non-entrance processing items");
- }
- return copiedSet;
- }
+ private int procItemCounter;
+
+ public SamzaTopology(String topoName) {
+ super(topoName);
+ procItemCounter = 0;
+ }
+
+ @Override
+ public void addProcessingItem(IProcessingItem procItem, int parallelism) {
+ super.addProcessingItem(procItem, parallelism);
+ SamzaProcessingNode samzaPi = (SamzaProcessingNode) procItem;
+ samzaPi.setName(this.getTopologyName() + "-" + Integer.toString(procItemCounter));
+ procItemCounter++;
+ }
+
+ /*
+ * Gets the set of ProcessingItems, excluding EntrancePIs. Used by
+ * SamzaConfigFactory, as the configs for EntrancePIs and normal PIs are
+ * different
+ */
+ public Set<IProcessingItem> getNonEntranceProcessingItems() throws Exception {
+ Set<IProcessingItem> copiedSet = new HashSet<IProcessingItem>();
+ copiedSet.addAll(this.getProcessingItems());
+ boolean result = copiedSet.removeAll(this.getEntranceProcessingItems());
+ if (!result) {
+ throw new Exception("Failed extracting the set of non-entrance processing items");
+ }
+ return copiedSet;
+ }
}
\ No newline at end of file
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaConfigFactory.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaConfigFactory.java
index 56427d0..03f35f1 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaConfigFactory.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaConfigFactory.java
@@ -51,482 +51,489 @@
import com.yahoo.labs.samoa.topology.impl.SamzaStream.SamzaSystemStream;
/**
- * Generate Configs that will be used to submit Samza jobs
- * from the input topology (one config per PI/EntrancePI in
- * the topology)
+ * Generate Configs that will be used to submit Samza jobs from the input
+ * topology (one config per PI/EntrancePI in the topology)
*
* @author Anh Thu Vu
- *
+ *
*/
public class SamzaConfigFactory {
- public static final String SYSTEM_NAME = "samoa";
-
- // DEFAULT VALUES
- private static final String DEFAULT_ZOOKEEPER = "localhost:2181";
- private static final String DEFAULT_BROKER_LIST = "localhost:9092";
+ public static final String SYSTEM_NAME = "samoa";
- // DELIMINATORS
- public static final String COMMA = ",";
- public static final String COLON = ":";
- public static final String DOT = ".";
- public static final char DOLLAR_SIGN = '$';
- public static final char QUESTION_MARK = '?';
-
- // PARTITIONING SCHEMES
- public static final String SHUFFLE = "shuffle";
- public static final String KEY = "key";
- public static final String BROADCAST = "broadcast";
+ // DEFAULT VALUES
+ private static final String DEFAULT_ZOOKEEPER = "localhost:2181";
+ private static final String DEFAULT_BROKER_LIST = "localhost:9092";
- // PROPERTY KEYS
- // JOB
- public static final String JOB_FACTORY_CLASS_KEY = "job.factory.class";
- public static final String JOB_NAME_KEY = "job.name";
- // YARN
- public static final String YARN_PACKAGE_KEY = "yarn.package.path";
- public static final String CONTAINER_MEMORY_KEY = "yarn.container.memory.mb";
- public static final String AM_MEMORY_KEY = "yarn.am.container.memory.mb";
- public static final String CONTAINER_COUNT_KEY = "yarn.container.count";
- // TASK (SAMZA original)
- public static final String TASK_CLASS_KEY = "task.class";
- public static final String TASK_INPUTS_KEY = "task.inputs";
- // TASK (extra)
- public static final String FILE_KEY = "task.processor.file";
- public static final String FILESYSTEM_KEY = "task.processor.filesystem";
- public static final String ENTRANCE_INPUT_KEY = "task.entrance.input";
- public static final String ENTRANCE_OUTPUT_KEY = "task.entrance.outputs";
- public static final String YARN_CONF_HOME_KEY = "yarn.config.home";
- // KAFKA
- public static final String ZOOKEEPER_URI_KEY = "consumer.zookeeper.connect";
- public static final String BROKER_URI_KEY = "producer.metadata.broker.list";
- public static final String KAFKA_BATCHSIZE_KEY = "producer.batch.num.messages";
- public static final String KAFKA_PRODUCER_TYPE_KEY = "producer.producer.type";
- // SERDE
- public static final String SERDE_REGISTRATION_KEY = "kryo.register";
+ // DELIMITERS
+ public static final String COMMA = ",";
+ public static final String COLON = ":";
+ public static final String DOT = ".";
+ public static final char DOLLAR_SIGN = '$';
+ public static final char QUESTION_MARK = '?';
- // Instance variables
- private boolean isLocalMode;
- private String zookeeper;
- private String kafkaBrokerList;
- private int replicationFactor;
- private int amMemory;
- private int containerMemory;
- private int piPerContainerRatio;
- private int checkpointFrequency; // in ms
+ // PARTITIONING SCHEMES
+ public static final String SHUFFLE = "shuffle";
+ public static final String KEY = "key";
+ public static final String BROADCAST = "broadcast";
- private String jarPath;
- private String kryoRegisterFile = null;
+ // PROPERTY KEYS
+ // JOB
+ public static final String JOB_FACTORY_CLASS_KEY = "job.factory.class";
+ public static final String JOB_NAME_KEY = "job.name";
+ // YARN
+ public static final String YARN_PACKAGE_KEY = "yarn.package.path";
+ public static final String CONTAINER_MEMORY_KEY = "yarn.container.memory.mb";
+ public static final String AM_MEMORY_KEY = "yarn.am.container.memory.mb";
+ public static final String CONTAINER_COUNT_KEY = "yarn.container.count";
+ // TASK (SAMZA original)
+ public static final String TASK_CLASS_KEY = "task.class";
+ public static final String TASK_INPUTS_KEY = "task.inputs";
+ // TASK (extra)
+ public static final String FILE_KEY = "task.processor.file";
+ public static final String FILESYSTEM_KEY = "task.processor.filesystem";
+ public static final String ENTRANCE_INPUT_KEY = "task.entrance.input";
+ public static final String ENTRANCE_OUTPUT_KEY = "task.entrance.outputs";
+ public static final String YARN_CONF_HOME_KEY = "yarn.config.home";
+ // KAFKA
+ public static final String ZOOKEEPER_URI_KEY = "consumer.zookeeper.connect";
+ public static final String BROKER_URI_KEY = "producer.metadata.broker.list";
+ public static final String KAFKA_BATCHSIZE_KEY = "producer.batch.num.messages";
+ public static final String KAFKA_PRODUCER_TYPE_KEY = "producer.producer.type";
+ // SERDE
+ public static final String SERDE_REGISTRATION_KEY = "kryo.register";
- public SamzaConfigFactory() {
- this.isLocalMode = false;
- this.zookeeper = DEFAULT_ZOOKEEPER;
- this.kafkaBrokerList = DEFAULT_BROKER_LIST;
- this.checkpointFrequency = 60000; // default: 1 minute
- this.replicationFactor = 1;
- }
+ // Instance variables
+ private boolean isLocalMode;
+ private String zookeeper;
+ private String kafkaBrokerList;
+ private int replicationFactor;
+ private int amMemory;
+ private int containerMemory;
+ private int piPerContainerRatio;
+ private int checkpointFrequency; // in ms
- /*
- * Setter methods
- */
- public SamzaConfigFactory setYarnPackage(String packagePath) {
- this.jarPath = packagePath;
- return this;
- }
+ private String jarPath;
+ private String kryoRegisterFile = null;
- public SamzaConfigFactory setLocalMode(boolean isLocal) {
- this.isLocalMode = isLocal;
- return this;
- }
+ public SamzaConfigFactory() {
+ this.isLocalMode = false;
+ this.zookeeper = DEFAULT_ZOOKEEPER;
+ this.kafkaBrokerList = DEFAULT_BROKER_LIST;
+ this.checkpointFrequency = 60000; // default: 1 minute
+ this.replicationFactor = 1;
+ }
- public SamzaConfigFactory setZookeeper(String zk) {
- this.zookeeper = zk;
- return this;
- }
+ /*
+ * Setter methods
+ */
+ public SamzaConfigFactory setYarnPackage(String packagePath) {
+ this.jarPath = packagePath;
+ return this;
+ }
- public SamzaConfigFactory setKafka(String brokerList) {
- this.kafkaBrokerList = brokerList;
- return this;
- }
-
- public SamzaConfigFactory setCheckpointFrequency(int freq) {
- this.checkpointFrequency = freq;
- return this;
- }
-
- public SamzaConfigFactory setReplicationFactor(int replicationFactor) {
- this.replicationFactor = replicationFactor;
- return this;
- }
+ public SamzaConfigFactory setLocalMode(boolean isLocal) {
+ this.isLocalMode = isLocal;
+ return this;
+ }
- public SamzaConfigFactory setAMMemory(int mem) {
- this.amMemory = mem;
- return this;
- }
+ public SamzaConfigFactory setZookeeper(String zk) {
+ this.zookeeper = zk;
+ return this;
+ }
- public SamzaConfigFactory setContainerMemory(int mem) {
- this.containerMemory = mem;
- return this;
- }
-
- public SamzaConfigFactory setPiPerContainerRatio(int piPerContainer) {
- this.piPerContainerRatio = piPerContainer;
- return this;
- }
+ public SamzaConfigFactory setKafka(String brokerList) {
+ this.kafkaBrokerList = brokerList;
+ return this;
+ }
- public SamzaConfigFactory setKryoRegisterFile(String kryoRegister) {
- this.kryoRegisterFile = kryoRegister;
- return this;
- }
+ public SamzaConfigFactory setCheckpointFrequency(int freq) {
+ this.checkpointFrequency = freq;
+ return this;
+ }
- /*
- * Generate a map of all config properties for the input SamzaProcessingItem
- */
- private Map<String,String> getMapForPI(SamzaProcessingItem pi, String filename, String filesystem) throws Exception {
- Map<String,String> map = getBasicSystemConfig();
+ public SamzaConfigFactory setReplicationFactor(int replicationFactor) {
+ this.replicationFactor = replicationFactor;
+ return this;
+ }
- // Set job name, task class, task inputs (from SamzaProcessingItem)
- setJobName(map, pi.getName());
- setTaskClass(map, SamzaProcessingItem.class.getName());
+ public SamzaConfigFactory setAMMemory(int mem) {
+ this.amMemory = mem;
+ return this;
+ }
- StringBuilder streamNames = new StringBuilder();
- boolean first = true;
- for(SamzaSystemStream stream:pi.getInputStreams()) {
- if (!first) streamNames.append(COMMA);
- streamNames.append(stream.getSystem()+DOT+stream.getStream());
- if (first) first = false;
- }
- setTaskInputs(map, streamNames.toString());
+ public SamzaConfigFactory setContainerMemory(int mem) {
+ this.containerMemory = mem;
+ return this;
+ }
- // Processor file
- setFileName(map, filename);
- setFileSystem(map, filesystem);
-
- List<String> nameList = new ArrayList<String>();
- // Default kafka system: kafka0: sync producer
- // This system is always required: it is used for checkpointing
- nameList.add("kafka0");
- setKafkaSystem(map, "kafka0", this.zookeeper, this.kafkaBrokerList, 1);
- // Output streams: set kafka systems
- for (SamzaStream stream:pi.getOutputStreams()) {
- boolean found = false;
- for (String name:nameList) {
- if (stream.getSystemName().equals(name)) {
- found = true;
- break;
- }
- }
- if (!found) {
- nameList.add(stream.getSystemName());
- setKafkaSystem(map, stream.getSystemName(), this.zookeeper, this.kafkaBrokerList, stream.getBatchSize());
- }
- }
- // Input streams: set kafka systems
- for (SamzaSystemStream stream:pi.getInputStreams()) {
- boolean found = false;
- for (String name:nameList) {
- if (stream.getSystem().equals(name)) {
- found = true;
- break;
- }
- }
- if (!found) {
- nameList.add(stream.getSystem());
- setKafkaSystem(map, stream.getSystem(), this.zookeeper, this.kafkaBrokerList, 1);
- }
- }
-
- // Checkpointing
- setValue(map,"task.checkpoint.factory","org.apache.samza.checkpoint.kafka.KafkaCheckpointManagerFactory");
- setValue(map,"task.checkpoint.system","kafka0");
- setValue(map,"task.commit.ms","1000");
- setValue(map,"task.checkpoint.replication.factor",Integer.toString(this.replicationFactor));
-
- // Number of containers
- setNumberOfContainers(map, pi.getParallelism(), this.piPerContainerRatio);
+ public SamzaConfigFactory setPiPerContainerRatio(int piPerContainer) {
+ this.piPerContainerRatio = piPerContainer;
+ return this;
+ }
- return map;
- }
+ public SamzaConfigFactory setKryoRegisterFile(String kryoRegister) {
+ this.kryoRegisterFile = kryoRegister;
+ return this;
+ }
- /*
- * Generate a map of all config properties for the input SamzaProcessingItem
- */
- public Map<String,String> getMapForEntrancePI(SamzaEntranceProcessingItem epi, String filename, String filesystem) {
- Map<String,String> map = getBasicSystemConfig();
+ /*
+ * Generate a map of all config properties for the input SamzaProcessingItem
+ */
+ private Map<String, String> getMapForPI(SamzaProcessingItem pi, String filename, String filesystem) throws Exception {
+ Map<String, String> map = getBasicSystemConfig();
- // Set job name, task class (from SamzaEntranceProcessingItem)
- setJobName(map, epi.getName());
- setTaskClass(map, SamzaEntranceProcessingItem.class.getName());
+ // Set job name, task class, task inputs (from SamzaProcessingItem)
+ setJobName(map, pi.getName());
+ setTaskClass(map, SamzaProcessingItem.class.getName());
- // Input for the entrance task (from our custom consumer)
- setTaskInputs(map, SYSTEM_NAME+"."+epi.getName());
+ StringBuilder streamNames = new StringBuilder();
+ boolean first = true;
+ for (SamzaSystemStream stream : pi.getInputStreams()) {
+ if (!first)
+ streamNames.append(COMMA);
+ streamNames.append(stream.getSystem() + DOT + stream.getStream());
+ if (first)
+ first = false;
+ }
+ setTaskInputs(map, streamNames.toString());
- // Output from entrance task
- // Since entrancePI should have only 1 output stream
- // there is no need for checking the batch size, setting different system names
- // The custom consumer (samoa system) does not suuport reading from a specific index
- // => no need for checkpointing
- SamzaStream outputStream = (SamzaStream)epi.getOutputStream();
- // Set samoa system factory
- setValue(map, "systems."+SYSTEM_NAME+".samza.factory", SamoaSystemFactory.class.getName());
- // Set Kafka system (only if there is an output stream)
- if (outputStream != null)
- setKafkaSystem(map, outputStream.getSystemName(), this.zookeeper, this.kafkaBrokerList, outputStream.getBatchSize());
+ // Processor file
+ setFileName(map, filename);
+ setFileSystem(map, filesystem);
- // Processor file
- setFileName(map, filename);
- setFileSystem(map, filesystem);
-
- // Number of containers
- setNumberOfContainers(map, 1, this.piPerContainerRatio);
+ List<String> nameList = new ArrayList<String>();
+ // Default kafka system: kafka0: sync producer
+ // This system is always required: it is used for checkpointing
+ nameList.add("kafka0");
+ setKafkaSystem(map, "kafka0", this.zookeeper, this.kafkaBrokerList, 1);
+ // Output streams: set kafka systems
+ for (SamzaStream stream : pi.getOutputStreams()) {
+ boolean found = false;
+ for (String name : nameList) {
+ if (stream.getSystemName().equals(name)) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ nameList.add(stream.getSystemName());
+ setKafkaSystem(map, stream.getSystemName(), this.zookeeper, this.kafkaBrokerList, stream.getBatchSize());
+ }
+ }
+ // Input streams: set kafka systems
+ for (SamzaSystemStream stream : pi.getInputStreams()) {
+ boolean found = false;
+ for (String name : nameList) {
+ if (stream.getSystem().equals(name)) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ nameList.add(stream.getSystem());
+ setKafkaSystem(map, stream.getSystem(), this.zookeeper, this.kafkaBrokerList, 1);
+ }
+ }
- return map;
- }
+ // Checkpointing
+ setValue(map, "task.checkpoint.factory", "org.apache.samza.checkpoint.kafka.KafkaCheckpointManagerFactory");
+ setValue(map, "task.checkpoint.system", "kafka0");
+ setValue(map, "task.commit.ms", "1000");
+ setValue(map, "task.checkpoint.replication.factor", Integer.toString(this.replicationFactor));
- /*
- * Generate a list of map (of config properties) for all PIs and EPI in
- * the input topology
- */
- public List<Map<String,String>> getMapsForTopology(SamzaTopology topology) throws Exception {
+ // Number of containers
+ setNumberOfContainers(map, pi.getParallelism(), this.piPerContainerRatio);
- List<Map<String,String>> maps = new ArrayList<Map<String,String>>();
+ return map;
+ }
- // File to write serialized objects
- String filename = topology.getTopologyName() + ".dat";
- Path dirPath = FileSystems.getDefault().getPath("dat");
- Path filePath= FileSystems.getDefault().getPath(dirPath.toString(), filename);
- String dstPath = filePath.toString();
- String resPath;
- String filesystem;
- if (this.isLocalMode) {
- filesystem = SystemsUtils.LOCAL_FS;
- File dir = dirPath.toFile();
- if (!dir.exists())
- FileUtils.forceMkdir(dir);
- }
- else {
- filesystem = SystemsUtils.HDFS;
- }
+ /*
+ * Generate a map of all config properties for the input SamzaProcessingItem
+ */
+ public Map<String, String> getMapForEntrancePI(SamzaEntranceProcessingItem epi, String filename, String filesystem) {
+ Map<String, String> map = getBasicSystemConfig();
- // Correct system name for streams
- this.setSystemNameForStreams(topology.getStreams());
-
- // Add all PIs to a collection (map)
- Map<String,Object> piMap = new HashMap<String,Object>();
- Set<EntranceProcessingItem> entranceProcessingItems = topology.getEntranceProcessingItems();
- Set<IProcessingItem> processingItems = topology.getNonEntranceProcessingItems();
- for(EntranceProcessingItem epi:entranceProcessingItems) {
- SamzaEntranceProcessingItem sepi = (SamzaEntranceProcessingItem) epi;
- piMap.put(sepi.getName(), sepi);
- }
- for(IProcessingItem pi:processingItems) {
- SamzaProcessingItem spi = (SamzaProcessingItem) pi;
- piMap.put(spi.getName(), spi);
- }
+ // Set job name, task class (from SamzaEntranceProcessingItem)
+ setJobName(map, epi.getName());
+ setTaskClass(map, SamzaEntranceProcessingItem.class.getName());
- // Serialize all PIs
- boolean serialized = false;
- if (this.isLocalMode) {
- serialized = SystemsUtils.serializeObjectToLocalFileSystem(piMap, dstPath);
- resPath = dstPath;
- }
- else {
- resPath = SystemsUtils.serializeObjectToHDFS(piMap, dstPath);
- serialized = resPath != null;
- }
+ // Input for the entrance task (from our custom consumer)
+ setTaskInputs(map, SYSTEM_NAME + "." + epi.getName());
- if (!serialized) {
- throw new Exception("Fail serialize map of PIs to file");
- }
+ // Output from entrance task
+ // Since entrancePI should have only 1 output stream
+ // there is no need for checking the batch size, setting different system
+ // names
+ // The custom consumer (samoa system) does not support reading from a
+ // specific index
+ // => no need for checkpointing
+ SamzaStream outputStream = (SamzaStream) epi.getOutputStream();
+ // Set samoa system factory
+ setValue(map, "systems." + SYSTEM_NAME + ".samza.factory", SamoaSystemFactory.class.getName());
+ // Set Kafka system (only if there is an output stream)
+ if (outputStream != null)
+ setKafkaSystem(map, outputStream.getSystemName(), this.zookeeper, this.kafkaBrokerList,
+ outputStream.getBatchSize());
- // MapConfig for all PIs
- for(EntranceProcessingItem epi:entranceProcessingItems) {
- SamzaEntranceProcessingItem sepi = (SamzaEntranceProcessingItem) epi;
- maps.add(this.getMapForEntrancePI(sepi, resPath, filesystem));
- }
- for(IProcessingItem pi:processingItems) {
- SamzaProcessingItem spi = (SamzaProcessingItem) pi;
- maps.add(this.getMapForPI(spi, resPath, filesystem));
- }
+ // Processor file
+ setFileName(map, filename);
+ setFileSystem(map, filesystem);
- return maps;
- }
+ // Number of containers
+ setNumberOfContainers(map, 1, this.piPerContainerRatio);
- /**
- * Construct a list of MapConfigs for a Topology
- * @return the list of MapConfigs
- * @throws Exception
- */
- public List<MapConfig> getMapConfigsForTopology(SamzaTopology topology) throws Exception {
- List<MapConfig> configs = new ArrayList<MapConfig>();
- List<Map<String,String>> maps = this.getMapsForTopology(topology);
- for(Map<String,String> map:maps) {
- configs.add(new MapConfig(map));
- }
- return configs;
- }
-
- /*
+ return map;
+ }
+
+ /*
+ * Generate a list of maps (of config properties) for all PIs and EPI in the
+ * input topology
+ */
+ public List<Map<String, String>> getMapsForTopology(SamzaTopology topology) throws Exception {
+
+ List<Map<String, String>> maps = new ArrayList<Map<String, String>>();
+
+ // File to write serialized objects
+ String filename = topology.getTopologyName() + ".dat";
+ Path dirPath = FileSystems.getDefault().getPath("dat");
+ Path filePath = FileSystems.getDefault().getPath(dirPath.toString(), filename);
+ String dstPath = filePath.toString();
+ String resPath;
+ String filesystem;
+ if (this.isLocalMode) {
+ filesystem = SystemsUtils.LOCAL_FS;
+ File dir = dirPath.toFile();
+ if (!dir.exists())
+ FileUtils.forceMkdir(dir);
+ }
+ else {
+ filesystem = SystemsUtils.HDFS;
+ }
+
+ // Correct system name for streams
+ this.setSystemNameForStreams(topology.getStreams());
+
+ // Add all PIs to a collection (map)
+ Map<String, Object> piMap = new HashMap<String, Object>();
+ Set<EntranceProcessingItem> entranceProcessingItems = topology.getEntranceProcessingItems();
+ Set<IProcessingItem> processingItems = topology.getNonEntranceProcessingItems();
+ for (EntranceProcessingItem epi : entranceProcessingItems) {
+ SamzaEntranceProcessingItem sepi = (SamzaEntranceProcessingItem) epi;
+ piMap.put(sepi.getName(), sepi);
+ }
+ for (IProcessingItem pi : processingItems) {
+ SamzaProcessingItem spi = (SamzaProcessingItem) pi;
+ piMap.put(spi.getName(), spi);
+ }
+
+ // Serialize all PIs
+ boolean serialized = false;
+ if (this.isLocalMode) {
+ serialized = SystemsUtils.serializeObjectToLocalFileSystem(piMap, dstPath);
+ resPath = dstPath;
+ }
+ else {
+ resPath = SystemsUtils.serializeObjectToHDFS(piMap, dstPath);
+ serialized = resPath != null;
+ }
+
+ if (!serialized) {
+ throw new Exception("Fail serialize map of PIs to file");
+ }
+
+ // MapConfig for all PIs
+ for (EntranceProcessingItem epi : entranceProcessingItems) {
+ SamzaEntranceProcessingItem sepi = (SamzaEntranceProcessingItem) epi;
+ maps.add(this.getMapForEntrancePI(sepi, resPath, filesystem));
+ }
+ for (IProcessingItem pi : processingItems) {
+ SamzaProcessingItem spi = (SamzaProcessingItem) pi;
+ maps.add(this.getMapForPI(spi, resPath, filesystem));
+ }
+
+ return maps;
+ }
+
+ /**
+ * Construct a list of MapConfigs for a Topology
+ *
+ * @return the list of MapConfigs
+ * @throws Exception
+ */
+ public List<MapConfig> getMapConfigsForTopology(SamzaTopology topology) throws Exception {
+ List<MapConfig> configs = new ArrayList<MapConfig>();
+ List<Map<String, String>> maps = this.getMapsForTopology(topology);
+ for (Map<String, String> map : maps) {
+ configs.add(new MapConfig(map));
+ }
+ return configs;
+ }
+
+ /*
*
*/
- public void setSystemNameForStreams(Set<Stream> streams) {
- Map<Integer, String> batchSizeMap = new HashMap<Integer, String>();
- batchSizeMap.put(1, "kafka0"); // default system with sync producer
- int counter = 0;
- for (Stream stream:streams) {
- SamzaStream samzaStream = (SamzaStream) stream;
- String systemName = batchSizeMap.get(samzaStream.getBatchSize());
- if (systemName == null) {
- counter++;
- // Add new system
- systemName = "kafka"+Integer.toString(counter);
- batchSizeMap.put(samzaStream.getBatchSize(), systemName);
- }
- samzaStream.setSystemName(systemName);
- }
+ public void setSystemNameForStreams(Set<Stream> streams) {
+ Map<Integer, String> batchSizeMap = new HashMap<Integer, String>();
+ batchSizeMap.put(1, "kafka0"); // default system with sync producer
+ int counter = 0;
+ for (Stream stream : streams) {
+ SamzaStream samzaStream = (SamzaStream) stream;
+ String systemName = batchSizeMap.get(samzaStream.getBatchSize());
+ if (systemName == null) {
+ counter++;
+ // Add new system
+ systemName = "kafka" + Integer.toString(counter);
+ batchSizeMap.put(samzaStream.getBatchSize(), systemName);
+ }
+ samzaStream.setSystemName(systemName);
+ }
-}
+ }
- /*
- * Generate a map with common properties for PIs and EPI
- */
- private Map<String,String> getBasicSystemConfig() {
- Map<String,String> map = new HashMap<String,String>();
- // Job & Task
- if (this.isLocalMode)
- map.put(JOB_FACTORY_CLASS_KEY, LocalJobFactory.class.getName());
- else {
- map.put(JOB_FACTORY_CLASS_KEY, YarnJobFactory.class.getName());
+ /*
+ * Generate a map with common properties for PIs and EPI
+ */
+ private Map<String, String> getBasicSystemConfig() {
+ Map<String, String> map = new HashMap<String, String>();
+ // Job & Task
+ if (this.isLocalMode)
+ map.put(JOB_FACTORY_CLASS_KEY, LocalJobFactory.class.getName());
+ else {
+ map.put(JOB_FACTORY_CLASS_KEY, YarnJobFactory.class.getName());
- // yarn
- map.put(YARN_PACKAGE_KEY,jarPath);
- map.put(CONTAINER_MEMORY_KEY, Integer.toString(this.containerMemory));
- map.put(AM_MEMORY_KEY, Integer.toString(this.amMemory));
- map.put(CONTAINER_COUNT_KEY, "1");
- map.put(YARN_CONF_HOME_KEY, SystemsUtils.getHadoopConfigHome());
-
- // Task opts (Heap size = 0.75 container memory)
- int heapSize = (int)(0.75*this.containerMemory);
- map.put("task.opts", "-Xmx"+Integer.toString(heapSize)+"M -XX:+PrintGCDateStamps");
- }
+ // yarn
+ map.put(YARN_PACKAGE_KEY, jarPath);
+ map.put(CONTAINER_MEMORY_KEY, Integer.toString(this.containerMemory));
+ map.put(AM_MEMORY_KEY, Integer.toString(this.amMemory));
+ map.put(CONTAINER_COUNT_KEY, "1");
+ map.put(YARN_CONF_HOME_KEY, SystemsUtils.getHadoopConfigHome());
+ // Task opts (Heap size = 0.75 container memory)
+ int heapSize = (int) (0.75 * this.containerMemory);
+ map.put("task.opts", "-Xmx" + Integer.toString(heapSize) + "M -XX:+PrintGCDateStamps");
+ }
- map.put(JOB_NAME_KEY, "");
- map.put(TASK_CLASS_KEY, "");
- map.put(TASK_INPUTS_KEY, "");
+ map.put(JOB_NAME_KEY, "");
+ map.put(TASK_CLASS_KEY, "");
+ map.put(TASK_INPUTS_KEY, "");
- // register serializer
- map.put("serializers.registry.kryo.class",SamzaKryoSerdeFactory.class.getName());
+ // register serializer
+ map.put("serializers.registry.kryo.class", SamzaKryoSerdeFactory.class.getName());
- // Serde registration
- setKryoRegistration(map, this.kryoRegisterFile);
+ // Serde registration
+ setKryoRegistration(map, this.kryoRegisterFile);
- return map;
- }
-
- /*
- * Helper methods to set different properties in the input map
- */
- private static void setJobName(Map<String,String> map, String jobName) {
- map.put(JOB_NAME_KEY, jobName);
- }
+ return map;
+ }
- private static void setFileName(Map<String,String> map, String filename) {
- map.put(FILE_KEY, filename);
- }
+ /*
+ * Helper methods to set different properties in the input map
+ */
+ private static void setJobName(Map<String, String> map, String jobName) {
+ map.put(JOB_NAME_KEY, jobName);
+ }
- private static void setFileSystem(Map<String,String> map, String filesystem) {
- map.put(FILESYSTEM_KEY, filesystem);
- }
+ private static void setFileName(Map<String, String> map, String filename) {
+ map.put(FILE_KEY, filename);
+ }
- private static void setTaskClass(Map<String,String> map, String taskClass) {
- map.put(TASK_CLASS_KEY, taskClass);
- }
+ private static void setFileSystem(Map<String, String> map, String filesystem) {
+ map.put(FILESYSTEM_KEY, filesystem);
+ }
- private static void setTaskInputs(Map<String,String> map, String inputs) {
- map.put(TASK_INPUTS_KEY, inputs);
- }
+ private static void setTaskClass(Map<String, String> map, String taskClass) {
+ map.put(TASK_CLASS_KEY, taskClass);
+ }
- private static void setKryoRegistration(Map<String, String> map, String kryoRegisterFile) {
- if (kryoRegisterFile != null) {
- String value = readKryoRegistration(kryoRegisterFile);
- map.put(SERDE_REGISTRATION_KEY, value);
- }
- }
-
- private static void setNumberOfContainers(Map<String, String> map, int parallelism, int piPerContainer) {
- int res = parallelism / piPerContainer;
- if (parallelism % piPerContainer != 0) res++;
- map.put(CONTAINER_COUNT_KEY, Integer.toString(res));
- }
-
- private static void setKafkaSystem(Map<String,String> map, String systemName, String zk, String brokers, int batchSize) {
- map.put("systems."+systemName+".samza.factory",KafkaSystemFactory.class.getName());
- map.put("systems."+systemName+".samza.msg.serde","kryo");
+ private static void setTaskInputs(Map<String, String> map, String inputs) {
+ map.put(TASK_INPUTS_KEY, inputs);
+ }
- map.put("systems."+systemName+"."+ZOOKEEPER_URI_KEY,zk);
- map.put("systems."+systemName+"."+BROKER_URI_KEY,brokers);
- map.put("systems."+systemName+"."+KAFKA_BATCHSIZE_KEY,Integer.toString(batchSize));
+ private static void setKryoRegistration(Map<String, String> map, String kryoRegisterFile) {
+ if (kryoRegisterFile != null) {
+ String value = readKryoRegistration(kryoRegisterFile);
+ map.put(SERDE_REGISTRATION_KEY, value);
+ }
+ }
- map.put("systems."+systemName+".samza.offset.default","oldest");
+ private static void setNumberOfContainers(Map<String, String> map, int parallelism, int piPerContainer) {
+ int res = parallelism / piPerContainer;
+ if (parallelism % piPerContainer != 0)
+ res++;
+ map.put(CONTAINER_COUNT_KEY, Integer.toString(res));
+ }
- if (batchSize > 1) {
- map.put("systems."+systemName+"."+KAFKA_PRODUCER_TYPE_KEY,"async");
- }
- else {
- map.put("systems."+systemName+"."+KAFKA_PRODUCER_TYPE_KEY,"sync");
- }
- }
-
- // Set custom properties
- private static void setValue(Map<String,String> map, String key, String value) {
- map.put(key,value);
- }
+ private static void setKafkaSystem(Map<String, String> map, String systemName, String zk, String brokers,
+ int batchSize) {
+ map.put("systems." + systemName + ".samza.factory", KafkaSystemFactory.class.getName());
+ map.put("systems." + systemName + ".samza.msg.serde", "kryo");
- /*
- * Helper method to parse Kryo registration file
- */
- private static String readKryoRegistration(String filePath) {
- FileInputStream fis = null;
- Properties props = new Properties();
- StringBuilder result = new StringBuilder();
- try {
- fis = new FileInputStream(filePath);
- props.load(fis);
+ map.put("systems." + systemName + "." + ZOOKEEPER_URI_KEY, zk);
+ map.put("systems." + systemName + "." + BROKER_URI_KEY, brokers);
+ map.put("systems." + systemName + "." + KAFKA_BATCHSIZE_KEY, Integer.toString(batchSize));
- boolean first = true;
- String value = null;
- for(String k:props.stringPropertyNames()) {
- if (!first)
- result.append(COMMA);
- else
- first = false;
-
- // Need to avoid the dollar sign as samza pass all the properties in
- // the config to containers via commandline parameters/enviroment variables
- // We might escape the dollar sign, but it's more complicated than
- // replacing it with something else
- result.append(k.trim().replace(DOLLAR_SIGN, QUESTION_MARK));
- value = props.getProperty(k);
- if (value != null && value.trim().length() > 0) {
- result.append(COLON);
- result.append(value.trim().replace(DOLLAR_SIGN, QUESTION_MARK));
- }
- }
- } catch (FileNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } finally {
- if (fis != null)
- try {
- fis.close();
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- }
+ map.put("systems." + systemName + ".samza.offset.default", "oldest");
- return result.toString();
- }
+ if (batchSize > 1) {
+ map.put("systems." + systemName + "." + KAFKA_PRODUCER_TYPE_KEY, "async");
+ }
+ else {
+ map.put("systems." + systemName + "." + KAFKA_PRODUCER_TYPE_KEY, "sync");
+ }
+ }
+
+ // Set custom properties
+ private static void setValue(Map<String, String> map, String key, String value) {
+ map.put(key, value);
+ }
+
+ /*
+ * Helper method to parse Kryo registration file
+ */
+ private static String readKryoRegistration(String filePath) {
+ FileInputStream fis = null;
+ Properties props = new Properties();
+ StringBuilder result = new StringBuilder();
+ try {
+ fis = new FileInputStream(filePath);
+ props.load(fis);
+
+ boolean first = true;
+ String value = null;
+ for (String k : props.stringPropertyNames()) {
+ if (!first)
+ result.append(COMMA);
+ else
+ first = false;
+
+ // Need to avoid the dollar sign as samza passes all the properties in
+ // the config to containers via commandline parameters/environment
+ // variables
+ // We might escape the dollar sign, but it's more complicated than
+ // replacing it with something else
+ result.append(k.trim().replace(DOLLAR_SIGN, QUESTION_MARK));
+ value = props.getProperty(k);
+ if (value != null && value.trim().length() > 0) {
+ result.append(COLON);
+ result.append(value.trim().replace(DOLLAR_SIGN, QUESTION_MARK));
+ }
+ }
+ } catch (FileNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } finally {
+ if (fis != null)
+ try {
+ fis.close();
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+
+ return result.toString();
+ }
}
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaKryoSerdeFactory.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaKryoSerdeFactory.java
index 8e9e446..cd4b846 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaKryoSerdeFactory.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SamzaKryoSerdeFactory.java
@@ -34,122 +34,126 @@
import com.esotericsoftware.kryo.io.Output;
/**
- * Implementation of Samza's SerdeFactory
- * that uses Kryo to serialize/deserialize objects
+ * Implementation of Samza's SerdeFactory that uses Kryo to
+ * serialize/deserialize objects
*
* @author Anh Thu Vu
* @param <T>
- *
+ *
*/
public class SamzaKryoSerdeFactory<T> implements SerdeFactory<T> {
-
- private static final Logger logger = LoggerFactory.getLogger(SamzaKryoSerdeFactory.class);
-
- public static class SamzaKryoSerde<V> implements Serde<V> {
- private Kryo kryo;
-
- public SamzaKryoSerde (String registrationInfo) {
- this.kryo = new Kryo();
- this.register(registrationInfo);
- }
-
- @SuppressWarnings({ "rawtypes", "unchecked" })
- private void register(String registrationInfo) {
- if (registrationInfo == null) return;
-
- String[] infoList = registrationInfo.split(SamzaConfigFactory.COMMA);
-
- Class targetClass = null;
- Class serializerClass = null;
- Serializer serializer = null;
-
- for (String info:infoList) {
- String[] fields = info.split(SamzaConfigFactory.COLON);
-
- targetClass = null;
- serializerClass = null;
- if (fields.length > 0) {
- try {
- targetClass = Class.forName(fields[0].replace(SamzaConfigFactory.QUESTION_MARK, SamzaConfigFactory.DOLLAR_SIGN));
- } catch (ClassNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- }
- if (fields.length > 1) {
- try {
- serializerClass = Class.forName(fields[1].replace(SamzaConfigFactory.QUESTION_MARK, SamzaConfigFactory.DOLLAR_SIGN));
- } catch (ClassNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- }
-
- if (targetClass != null) {
- if (serializerClass == null) {
- kryo.register(targetClass);
- }
- else {
- serializer = resolveSerializerInstance(kryo, targetClass, (Class<? extends Serializer>)serializerClass) ;
- kryo.register(targetClass, serializer);
- }
- }
- else {
- logger.info("Invalid registration info:{}",info);
- }
- }
- }
-
- @SuppressWarnings("rawtypes")
- private static Serializer resolveSerializerInstance(Kryo k, Class superClass, Class<? extends Serializer> serializerClass) {
- try {
- try {
- return serializerClass.getConstructor(Kryo.class, Class.class).newInstance(k, superClass);
- } catch (Exception ex1) {
- try {
- return serializerClass.getConstructor(Kryo.class).newInstance(k);
- } catch (Exception ex2) {
- try {
- return serializerClass.getConstructor(Class.class).newInstance(superClass);
- } catch (Exception ex3) {
- return serializerClass.newInstance();
- }
- }
- }
- } catch (Exception ex) {
- throw new IllegalArgumentException("Unable to create serializer \""
- + serializerClass.getName()
- + "\" for class: "
- + superClass.getName(), ex);
- }
- }
-
- /*
- * Implement Samza Serde interface
- */
- @Override
- public byte[] toBytes(V obj) {
- ByteArrayOutputStream bos = new ByteArrayOutputStream();
- Output output = new Output(bos);
- kryo.writeClassAndObject(output, obj);
- output.flush();
- output.close();
- return bos.toByteArray();
- }
- @SuppressWarnings("unchecked")
- @Override
- public V fromBytes(byte[] byteArr) {
- Input input = new Input(byteArr);
- Object obj = kryo.readClassAndObject(input);
- input.close();
- return (V)obj;
- }
-
- }
+ private static final Logger logger = LoggerFactory.getLogger(SamzaKryoSerdeFactory.class);
- @Override
- public Serde<T> getSerde(String name, Config config) {
- return new SamzaKryoSerde<T>(config.get(SamzaConfigFactory.SERDE_REGISTRATION_KEY));
- }
+ public static class SamzaKryoSerde<V> implements Serde<V> {
+ private Kryo kryo;
+
+ public SamzaKryoSerde(String registrationInfo) {
+ this.kryo = new Kryo();
+ this.register(registrationInfo);
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ private void register(String registrationInfo) {
+ if (registrationInfo == null)
+ return;
+
+ String[] infoList = registrationInfo.split(SamzaConfigFactory.COMMA);
+
+ Class targetClass = null;
+ Class serializerClass = null;
+ Serializer serializer = null;
+
+ for (String info : infoList) {
+ String[] fields = info.split(SamzaConfigFactory.COLON);
+
+ targetClass = null;
+ serializerClass = null;
+ if (fields.length > 0) {
+ try {
+ targetClass = Class.forName(fields[0].replace(SamzaConfigFactory.QUESTION_MARK,
+ SamzaConfigFactory.DOLLAR_SIGN));
+ } catch (ClassNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+ if (fields.length > 1) {
+ try {
+ serializerClass = Class.forName(fields[1].replace(SamzaConfigFactory.QUESTION_MARK,
+ SamzaConfigFactory.DOLLAR_SIGN));
+ } catch (ClassNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+
+ if (targetClass != null) {
+ if (serializerClass == null) {
+ kryo.register(targetClass);
+ }
+ else {
+ serializer = resolveSerializerInstance(kryo, targetClass, (Class<? extends Serializer>) serializerClass);
+ kryo.register(targetClass, serializer);
+ }
+ }
+ else {
+ logger.info("Invalid registration info:{}", info);
+ }
+ }
+ }
+
+ @SuppressWarnings("rawtypes")
+ private static Serializer resolveSerializerInstance(Kryo k, Class superClass,
+ Class<? extends Serializer> serializerClass) {
+ try {
+ try {
+ return serializerClass.getConstructor(Kryo.class, Class.class).newInstance(k, superClass);
+ } catch (Exception ex1) {
+ try {
+ return serializerClass.getConstructor(Kryo.class).newInstance(k);
+ } catch (Exception ex2) {
+ try {
+ return serializerClass.getConstructor(Class.class).newInstance(superClass);
+ } catch (Exception ex3) {
+ return serializerClass.newInstance();
+ }
+ }
+ }
+ } catch (Exception ex) {
+ throw new IllegalArgumentException("Unable to create serializer \""
+ + serializerClass.getName()
+ + "\" for class: "
+ + superClass.getName(), ex);
+ }
+ }
+
+ /*
+ * Implement Samza Serde interface
+ */
+ @Override
+ public byte[] toBytes(V obj) {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ Output output = new Output(bos);
+ kryo.writeClassAndObject(output, obj);
+ output.flush();
+ output.close();
+ return bos.toByteArray();
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public V fromBytes(byte[] byteArr) {
+ Input input = new Input(byteArr);
+ Object obj = kryo.readClassAndObject(input);
+ input.close();
+ return (V) obj;
+ }
+
+ }
+
+ @Override
+ public Serde<T> getSerde(String name, Config config) {
+ return new SamzaKryoSerde<T>(config.get(SamzaConfigFactory.SERDE_REGISTRATION_KEY));
+ }
}
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SerializableSerializer.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SerializableSerializer.java
index 4bdbafd..62fa3fd 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SerializableSerializer.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SerializableSerializer.java
@@ -37,34 +37,34 @@
* @author Anh Thu Vu
*/
public class SerializableSerializer extends Serializer<Object> {
- @Override
- public void write(Kryo kryo, Output output, Object object) {
- ByteArrayOutputStream bos = new ByteArrayOutputStream();
- try {
- ObjectOutputStream oos = new ObjectOutputStream(bos);
- oos.writeObject(object);
- oos.flush();
- } catch(IOException e) {
- throw new RuntimeException(e);
- }
- byte[] ser = bos.toByteArray();
- output.writeInt(ser.length);
- output.writeBytes(ser);
+ @Override
+ public void write(Kryo kryo, Output output, Object object) {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ try {
+ ObjectOutputStream oos = new ObjectOutputStream(bos);
+ oos.writeObject(object);
+ oos.flush();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
}
-
- @SuppressWarnings("rawtypes")
- @Override
- public Object read(Kryo kryo, Input input, Class c) {
- int len = input.readInt();
- byte[] ser = new byte[len];
- input.readBytes(ser);
- ByteArrayInputStream bis = new ByteArrayInputStream(ser);
- try {
- ObjectInputStream ois = new ObjectInputStream(bis);
- return ois.readObject();
- } catch(Exception e) {
- throw new RuntimeException(e);
- }
+ byte[] ser = bos.toByteArray();
+ output.writeInt(ser.length);
+ output.writeBytes(ser);
+ }
+
+ @SuppressWarnings("rawtypes")
+ @Override
+ public Object read(Kryo kryo, Input input, Class c) {
+ int len = input.readInt();
+ byte[] ser = new byte[len];
+ input.readBytes(ser);
+ ByteArrayInputStream bis = new ByteArrayInputStream(ser);
+ try {
+ ObjectInputStream ois = new ObjectInputStream(bis);
+ return ois.readObject();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
}
+ }
}
diff --git a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SystemsUtils.java b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SystemsUtils.java
index 367f9f9..f8e6dcd 100644
--- a/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SystemsUtils.java
+++ b/samoa-samza/src/main/java/com/yahoo/labs/samoa/utils/SystemsUtils.java
@@ -45,339 +45,342 @@
import org.slf4j.LoggerFactory;
/**
- * Utilities methods for:
- * - Kafka
- * - HDFS
- * - Handling files on local FS
+ * Utility methods for: - Kafka - HDFS - Handling files on local FS
*
* @author Anh Thu Vu
*/
public class SystemsUtils {
- private static final Logger logger = LoggerFactory.getLogger(SystemsUtils.class);
-
- public static final String HDFS = "hdfs";
- public static final String LOCAL_FS = "local";
-
- private static final String TEMP_FILE = "samoaTemp";
- private static final String TEMP_FILE_SUFFIX = ".dat";
-
- /*
- * Kafka
- */
- private static class KafkaUtils {
- private static ZkClient zkClient;
-
- static void setZookeeper(String zk) {
- zkClient = new ZkClient(zk, 30000, 30000, new ZKStringSerializerWrapper());
- }
+ private static final Logger logger = LoggerFactory.getLogger(SystemsUtils.class);
- /*
- * Create Kafka topic/stream
- */
- static void createKafkaTopic(String name, int partitions, int replicas) {
- AdminUtils.createTopic(zkClient, name, partitions, replicas, new Properties());
- }
-
- static class ZKStringSerializerWrapper implements ZkSerializer {
- @Override
- public Object deserialize(byte[] byteArray) throws ZkMarshallingError {
- return ZKStringSerializer.deserialize(byteArray);
- }
+ public static final String HDFS = "hdfs";
+ public static final String LOCAL_FS = "local";
- @Override
- public byte[] serialize(Object obj) throws ZkMarshallingError {
- return ZKStringSerializer.serialize(obj);
- }
- }
- }
-
- /*
- * HDFS
- */
- private static class HDFSUtils {
- private static String coreConfPath;
- private static String hdfsConfPath;
- private static String configHomePath;
- private static String samoaDir = null;
-
- static void setHadoopConfigHome(String hadoopConfPath) {
- logger.info("Hadoop config home:{}",hadoopConfPath);
- configHomePath = hadoopConfPath;
- java.nio.file.Path coreSitePath = FileSystems.getDefault().getPath(hadoopConfPath, "core-site.xml");
- java.nio.file.Path hdfsSitePath = FileSystems.getDefault().getPath(hadoopConfPath, "hdfs-site.xml");
- coreConfPath = coreSitePath.toAbsolutePath().toString();
- hdfsConfPath = hdfsSitePath.toAbsolutePath().toString();
- }
-
- static String getNameNodeUri() {
- Configuration config = new Configuration();
- config.addResource(new Path(coreConfPath));
- config.addResource(new Path(hdfsConfPath));
-
- return config.get("fs.defaultFS");
- }
-
- static String getHadoopConfigHome() {
- return configHomePath;
- }
-
- static void setSAMOADir(String dir) {
- if (dir != null)
- samoaDir = getNameNodeUri()+dir;
- else
- samoaDir = null;
- }
-
- static String getDefaultSAMOADir() throws IOException {
- Configuration config = new Configuration();
- config.addResource(new Path(coreConfPath));
- config.addResource(new Path(hdfsConfPath));
-
- FileSystem fs = FileSystem.get(config);
- Path defaultDir = new Path(fs.getHomeDirectory(),".samoa");
- return defaultDir.toString();
- }
-
- static boolean deleteFileIfExist(String absPath) {
- Path p = new Path(absPath);
- return deleteFileIfExist(p);
- }
-
- static boolean deleteFileIfExist(Path p) {
- Configuration config = new Configuration();
- config.addResource(new Path(coreConfPath));
- config.addResource(new Path(hdfsConfPath));
-
- FileSystem fs;
- try {
- fs = FileSystem.get(config);
- if (fs.exists(p)) {
- return fs.delete(p, false);
- }
- else
- return true;
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- return false;
- }
-
- /*
- * Write to HDFS
- */
- static String writeToHDFS(File file, String dstPath) {
- Configuration config = new Configuration();
- config.addResource(new Path(coreConfPath));
- config.addResource(new Path(hdfsConfPath));
- logger.info("Filesystem name:{}",config.get("fs.defaultFS"));
-
- // Default samoaDir
- if (samoaDir == null) {
- try {
- samoaDir = getDefaultSAMOADir();
- }
- catch (IOException e) {
- e.printStackTrace();
- return null;
- }
- }
-
- // Setup src and dst paths
- //java.nio.file.Path tempPath = FileSystems.getDefault().getPath(samoaDir, dstPath);
- Path dst = new Path(samoaDir,dstPath);
- Path src = new Path(file.getAbsolutePath());
-
- // Delete file if already exists in HDFS
- if (deleteFileIfExist(dst) == false)
- return null;
-
- // Copy to HDFS
- FileSystem fs;
- try {
- fs = FileSystem.get(config);
- fs.copyFromLocalFile(src, dst);
- } catch (IOException e) {
- e.printStackTrace();
- return null;
- }
-
- return dst.toString(); // abs path to file
- }
-
- /*
- * Read from HDFS
- */
- static Object deserializeObjectFromFile(String filePath) {
- logger.info("Deserialize HDFS file:{}",filePath);
- Configuration config = new Configuration();
- config.addResource(new Path(coreConfPath));
- config.addResource(new Path(hdfsConfPath));
-
- Path file = new Path(filePath);
- FSDataInputStream dataInputStream = null;
- ObjectInputStream ois = null;
- Object obj = null;
- FileSystem fs;
- try {
- fs = FileSystem.get(config);
- dataInputStream = fs.open(file);
- ois = new ObjectInputStream(dataInputStream);
- obj = ois.readObject();
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } catch (ClassNotFoundException e) {
- try {
- if (dataInputStream != null) dataInputStream.close();
- if (ois != null) ois.close();
- } catch (IOException ioException) {
- // TODO auto-generated catch block
- e.printStackTrace();
- }
- }
-
- return obj;
- }
-
- }
-
- private static class LocalFileSystemUtils {
- static boolean serializObjectToFile(Object obj, String fn) {
- FileOutputStream fos = null;
- ObjectOutputStream oos = null;
- try {
- fos = new FileOutputStream(fn);
- oos = new ObjectOutputStream(fos);
- oos.writeObject(obj);
- } catch (FileNotFoundException e) {
- e.printStackTrace();
- return false;
- } catch (IOException e) {
- e.printStackTrace();
- return false;
- } finally {
- try {
- if (fos != null) fos.close();
- if (oos != null) oos.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
+ private static final String TEMP_FILE = "samoaTemp";
+ private static final String TEMP_FILE_SUFFIX = ".dat";
- return true;
- }
-
- static Object deserializeObjectFromLocalFile(String filename) {
- logger.info("Deserialize local file:{}",filename);
- FileInputStream fis = null;
- ObjectInputStream ois = null;
- Object obj = null;
- try {
- fis = new FileInputStream(filename);
- ois = new ObjectInputStream(fis);
- obj = ois.readObject();
- } catch (IOException e) {
- // TODO auto-generated catch block
- e.printStackTrace();
- } catch (ClassNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } finally {
- try {
- if (fis != null) fis.close();
- if (ois != null) ois.close();
- } catch (IOException e) {
- // TODO auto-generated catch block
- e.printStackTrace();
- }
- }
+ /*
+ * Kafka
+ */
+ private static class KafkaUtils {
+ private static ZkClient zkClient;
- return obj;
- }
- }
-
-
-
- /*
- * Create streams
- */
- public static void createKafkaTopic(String name, int partitions) {
- createKafkaTopic(name, partitions, 1);
- }
-
- public static void createKafkaTopic(String name, int partitions, int replicas) {
- KafkaUtils.createKafkaTopic(name, partitions, replicas);
- }
-
- /*
- * Serialize object
- */
- public static boolean serializeObjectToLocalFileSystem(Object object, String path) {
- return LocalFileSystemUtils.serializObjectToFile(object, path);
- }
-
- public static String serializeObjectToHDFS(Object object, String path) {
- File tmpDatFile;
- try {
- tmpDatFile = File.createTempFile(TEMP_FILE, TEMP_FILE_SUFFIX);
- if (serializeObjectToLocalFileSystem(object, tmpDatFile.getAbsolutePath())) {
- return HDFSUtils.writeToHDFS(tmpDatFile, path);
- }
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- return null;
- }
-
- /*
- * Deserialize object
- */
- @SuppressWarnings("unchecked")
- public static Map<String,Object> deserializeMapFromFile(String filesystem, String filename) {
- Map<String,Object> map;
- if (filesystem.equals(HDFS)) {
- map = (Map<String,Object>) HDFSUtils.deserializeObjectFromFile(filename);
- }
- else {
- map = (Map<String,Object>) LocalFileSystemUtils.deserializeObjectFromLocalFile(filename);
- }
- return map;
- }
-
- public static Object deserializeObjectFromFileAndKey(String filesystem, String filename, String key) {
- Map<String,Object> map = deserializeMapFromFile(filesystem, filename);
- if (map == null) return null;
- return map.get(key);
- }
-
- /*
- * Setup
- */
- public static void setZookeeper(String zookeeper) {
- KafkaUtils.setZookeeper(zookeeper);
- }
-
- public static void setHadoopConfigHome(String hadoopHome) {
- HDFSUtils.setHadoopConfigHome(hadoopHome);
- }
-
- public static void setSAMOADir(String samoaDir) {
- HDFSUtils.setSAMOADir(samoaDir);
- }
-
- /*
- * Others
- */
- public static String getHDFSNameNodeUri() {
- return HDFSUtils.getNameNodeUri();
- }
- public static String getHadoopConfigHome() {
- return HDFSUtils.getHadoopConfigHome();
- }
-
- public static String copyToHDFS(File file, String dstPath) {
- return HDFSUtils.writeToHDFS(file, dstPath);
- }
+ static void setZookeeper(String zk) {
+ zkClient = new ZkClient(zk, 30000, 30000, new ZKStringSerializerWrapper());
+ }
+
+ /*
+ * Create Kafka topic/stream
+ */
+ static void createKafkaTopic(String name, int partitions, int replicas) {
+ AdminUtils.createTopic(zkClient, name, partitions, replicas, new Properties());
+ }
+
+ static class ZKStringSerializerWrapper implements ZkSerializer {
+ @Override
+ public Object deserialize(byte[] byteArray) throws ZkMarshallingError {
+ return ZKStringSerializer.deserialize(byteArray);
+ }
+
+ @Override
+ public byte[] serialize(Object obj) throws ZkMarshallingError {
+ return ZKStringSerializer.serialize(obj);
+ }
+ }
+ }
+
+ /*
+ * HDFS
+ */
+ private static class HDFSUtils {
+ private static String coreConfPath;
+ private static String hdfsConfPath;
+ private static String configHomePath;
+ private static String samoaDir = null;
+
+ static void setHadoopConfigHome(String hadoopConfPath) {
+ logger.info("Hadoop config home:{}", hadoopConfPath);
+ configHomePath = hadoopConfPath;
+ java.nio.file.Path coreSitePath = FileSystems.getDefault().getPath(hadoopConfPath, "core-site.xml");
+ java.nio.file.Path hdfsSitePath = FileSystems.getDefault().getPath(hadoopConfPath, "hdfs-site.xml");
+ coreConfPath = coreSitePath.toAbsolutePath().toString();
+ hdfsConfPath = hdfsSitePath.toAbsolutePath().toString();
+ }
+
+ static String getNameNodeUri() {
+ Configuration config = new Configuration();
+ config.addResource(new Path(coreConfPath));
+ config.addResource(new Path(hdfsConfPath));
+
+ return config.get("fs.defaultFS");
+ }
+
+ static String getHadoopConfigHome() {
+ return configHomePath;
+ }
+
+ static void setSAMOADir(String dir) {
+ if (dir != null)
+ samoaDir = getNameNodeUri() + dir;
+ else
+ samoaDir = null;
+ }
+
+ static String getDefaultSAMOADir() throws IOException {
+ Configuration config = new Configuration();
+ config.addResource(new Path(coreConfPath));
+ config.addResource(new Path(hdfsConfPath));
+
+ FileSystem fs = FileSystem.get(config);
+ Path defaultDir = new Path(fs.getHomeDirectory(), ".samoa");
+ return defaultDir.toString();
+ }
+
+ static boolean deleteFileIfExist(String absPath) {
+ Path p = new Path(absPath);
+ return deleteFileIfExist(p);
+ }
+
+ static boolean deleteFileIfExist(Path p) {
+ Configuration config = new Configuration();
+ config.addResource(new Path(coreConfPath));
+ config.addResource(new Path(hdfsConfPath));
+
+ FileSystem fs;
+ try {
+ fs = FileSystem.get(config);
+ if (fs.exists(p)) {
+ return fs.delete(p, false);
+ }
+ else
+ return true;
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ return false;
+ }
+
+ /*
+ * Write to HDFS
+ */
+ static String writeToHDFS(File file, String dstPath) {
+ Configuration config = new Configuration();
+ config.addResource(new Path(coreConfPath));
+ config.addResource(new Path(hdfsConfPath));
+ logger.info("Filesystem name:{}", config.get("fs.defaultFS"));
+
+ // Default samoaDir
+ if (samoaDir == null) {
+ try {
+ samoaDir = getDefaultSAMOADir();
+ } catch (IOException e) {
+ e.printStackTrace();
+ return null;
+ }
+ }
+
+ // Setup src and dst paths
+ // java.nio.file.Path tempPath =
+ // FileSystems.getDefault().getPath(samoaDir, dstPath);
+ Path dst = new Path(samoaDir, dstPath);
+ Path src = new Path(file.getAbsolutePath());
+
+ // Delete file if already exists in HDFS
+ if (deleteFileIfExist(dst) == false)
+ return null;
+
+ // Copy to HDFS
+ FileSystem fs;
+ try {
+ fs = FileSystem.get(config);
+ fs.copyFromLocalFile(src, dst);
+ } catch (IOException e) {
+ e.printStackTrace();
+ return null;
+ }
+
+ return dst.toString(); // abs path to file
+ }
+
+ /*
+ * Read from HDFS
+ */
+ static Object deserializeObjectFromFile(String filePath) {
+ logger.info("Deserialize HDFS file:{}", filePath);
+ Configuration config = new Configuration();
+ config.addResource(new Path(coreConfPath));
+ config.addResource(new Path(hdfsConfPath));
+
+ Path file = new Path(filePath);
+ FSDataInputStream dataInputStream = null;
+ ObjectInputStream ois = null;
+ Object obj = null;
+ FileSystem fs;
+ try {
+ fs = FileSystem.get(config);
+ dataInputStream = fs.open(file);
+ ois = new ObjectInputStream(dataInputStream);
+ obj = ois.readObject();
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (ClassNotFoundException e) {
+ try {
+ if (dataInputStream != null)
+ dataInputStream.close();
+ if (ois != null)
+ ois.close();
+ } catch (IOException ioException) {
+ // TODO auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+
+ return obj;
+ }
+
+ }
+
+ private static class LocalFileSystemUtils {
+ static boolean serializObjectToFile(Object obj, String fn) {
+ FileOutputStream fos = null;
+ ObjectOutputStream oos = null;
+ try {
+ fos = new FileOutputStream(fn);
+ oos = new ObjectOutputStream(fos);
+ oos.writeObject(obj);
+ } catch (FileNotFoundException e) {
+ e.printStackTrace();
+ return false;
+ } catch (IOException e) {
+ e.printStackTrace();
+ return false;
+ } finally {
+ try {
+ if (fos != null)
+ fos.close();
+ if (oos != null)
+ oos.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ return true;
+ }
+
+ static Object deserializeObjectFromLocalFile(String filename) {
+ logger.info("Deserialize local file:{}", filename);
+ FileInputStream fis = null;
+ ObjectInputStream ois = null;
+ Object obj = null;
+ try {
+ fis = new FileInputStream(filename);
+ ois = new ObjectInputStream(fis);
+ obj = ois.readObject();
+ } catch (IOException e) {
+ // TODO auto-generated catch block
+ e.printStackTrace();
+ } catch (ClassNotFoundException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } finally {
+ try {
+ if (fis != null)
+ fis.close();
+ if (ois != null)
+ ois.close();
+ } catch (IOException e) {
+ // TODO auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+
+ return obj;
+ }
+ }
+
+ /*
+ * Create streams
+ */
+ public static void createKafkaTopic(String name, int partitions) {
+ createKafkaTopic(name, partitions, 1);
+ }
+
+ public static void createKafkaTopic(String name, int partitions, int replicas) {
+ KafkaUtils.createKafkaTopic(name, partitions, replicas);
+ }
+
+ /*
+ * Serialize object
+ */
+ public static boolean serializeObjectToLocalFileSystem(Object object, String path) {
+ return LocalFileSystemUtils.serializObjectToFile(object, path);
+ }
+
+ public static String serializeObjectToHDFS(Object object, String path) {
+ File tmpDatFile;
+ try {
+ tmpDatFile = File.createTempFile(TEMP_FILE, TEMP_FILE_SUFFIX);
+ if (serializeObjectToLocalFileSystem(object, tmpDatFile.getAbsolutePath())) {
+ return HDFSUtils.writeToHDFS(tmpDatFile, path);
+ }
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ return null;
+ }
+
+ /*
+ * Deserialize object
+ */
+ @SuppressWarnings("unchecked")
+ public static Map<String, Object> deserializeMapFromFile(String filesystem, String filename) {
+ Map<String, Object> map;
+ if (filesystem.equals(HDFS)) {
+ map = (Map<String, Object>) HDFSUtils.deserializeObjectFromFile(filename);
+ }
+ else {
+ map = (Map<String, Object>) LocalFileSystemUtils.deserializeObjectFromLocalFile(filename);
+ }
+ return map;
+ }
+
+ public static Object deserializeObjectFromFileAndKey(String filesystem, String filename, String key) {
+ Map<String, Object> map = deserializeMapFromFile(filesystem, filename);
+ if (map == null)
+ return null;
+ return map.get(key);
+ }
+
+ /*
+ * Setup
+ */
+ public static void setZookeeper(String zookeeper) {
+ KafkaUtils.setZookeeper(zookeeper);
+ }
+
+ public static void setHadoopConfigHome(String hadoopHome) {
+ HDFSUtils.setHadoopConfigHome(hadoopHome);
+ }
+
+ public static void setSAMOADir(String samoaDir) {
+ HDFSUtils.setSAMOADir(samoaDir);
+ }
+
+ /*
+ * Others
+ */
+ public static String getHDFSNameNodeUri() {
+ return HDFSUtils.getNameNodeUri();
+ }
+
+ public static String getHadoopConfigHome() {
+ return HDFSUtils.getHadoopConfigHome();
+ }
+
+ public static String copyToHDFS(File file, String dstPath) {
+ return HDFSUtils.writeToHDFS(file, dstPath);
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/LocalStormDoTask.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/LocalStormDoTask.java
index 54792ae..d6ea26e 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/LocalStormDoTask.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/LocalStormDoTask.java
@@ -34,45 +34,46 @@
/**
* The main class to execute a SAMOA task in LOCAL mode in Storm.
- *
+ *
* @author Arinto Murdopo
- *
+ *
*/
public class LocalStormDoTask {
- private static final Logger logger = LoggerFactory.getLogger(LocalStormDoTask.class);
+ private static final Logger logger = LoggerFactory.getLogger(LocalStormDoTask.class);
- /**
- * The main method.
- *
- * @param args the arguments
- */
- public static void main(String[] args) {
+ /**
+ * The main method.
+ *
+ * @param args
+ * the arguments
+ */
+ public static void main(String[] args) {
- List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
+ List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
- int numWorker = StormSamoaUtils.numWorkers(tmpArgs);
+ int numWorker = StormSamoaUtils.numWorkers(tmpArgs);
- args = tmpArgs.toArray(new String[0]);
+ args = tmpArgs.toArray(new String[0]);
- //convert the arguments into Storm topology
- StormTopology stormTopo = StormSamoaUtils.argsToTopology(args);
- String topologyName = stormTopo.getTopologyName();
+ // convert the arguments into Storm topology
+ StormTopology stormTopo = StormSamoaUtils.argsToTopology(args);
+ String topologyName = stormTopo.getTopologyName();
- Config conf = new Config();
- //conf.putAll(Utils.readStormConfig());
- conf.setDebug(false);
+ Config conf = new Config();
+ // conf.putAll(Utils.readStormConfig());
+ conf.setDebug(false);
- //local mode
- conf.setMaxTaskParallelism(numWorker);
+ // local mode
+ conf.setMaxTaskParallelism(numWorker);
- backtype.storm.LocalCluster cluster = new backtype.storm.LocalCluster();
- cluster.submitTopology(topologyName, conf, stormTopo.getStormBuilder().createTopology());
+ backtype.storm.LocalCluster cluster = new backtype.storm.LocalCluster();
+ cluster.submitTopology(topologyName, conf, stormTopo.getStormBuilder().createTopology());
- backtype.storm.utils.Utils.sleep(600 * 1000);
+ backtype.storm.utils.Utils.sleep(600 * 1000);
- cluster.killTopology(topologyName);
- cluster.shutdown();
+ cluster.killTopology(topologyName);
+ cluster.shutdown();
- }
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormBoltStream.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormBoltStream.java
index ad9794d..84a9336 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormBoltStream.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormBoltStream.java
@@ -26,40 +26,41 @@
/**
* Storm Stream that connects into Bolt. It wraps Storm's outputCollector class
+ *
* @author Arinto Murdopo
- *
+ *
*/
-class StormBoltStream extends StormStream{
-
- /**
+class StormBoltStream extends StormStream {
+
+ /**
*
*/
- private static final long serialVersionUID = -5712513402991550847L;
-
- private OutputCollector outputCollector;
+ private static final long serialVersionUID = -5712513402991550847L;
- StormBoltStream(String stormComponentId){
- super(stormComponentId);
- }
+ private OutputCollector outputCollector;
- @Override
- public void put(ContentEvent contentEvent) {
- outputCollector.emit(this.outputStreamId, new Values(contentEvent, contentEvent.getKey()));
- }
-
- public void setCollector(OutputCollector outputCollector){
- this.outputCollector = outputCollector;
- }
+ StormBoltStream(String stormComponentId) {
+ super(stormComponentId);
+ }
-// @Override
-// public void setStreamId(String streamId) {
-// // TODO Auto-generated method stub
-// //this.outputStreamId = streamId;
-// }
+ @Override
+ public void put(ContentEvent contentEvent) {
+ outputCollector.emit(this.outputStreamId, new Values(contentEvent, contentEvent.getKey()));
+ }
- @Override
- public String getStreamId() {
- // TODO Auto-generated method stub
- return null;
- }
+ public void setCollector(OutputCollector outputCollector) {
+ this.outputCollector = outputCollector;
+ }
+
+ // @Override
+ // public void setStreamId(String streamId) {
+ // // TODO Auto-generated method stub
+ // //this.outputStreamId = streamId;
+ // }
+
+ @Override
+ public String getStreamId() {
+ // TODO Auto-generated method stub
+ return null;
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormComponentFactory.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormComponentFactory.java
index 347fd50..9a2bc65 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormComponentFactory.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormComponentFactory.java
@@ -37,54 +37,54 @@
*/
public final class StormComponentFactory implements ComponentFactory {
- private final Map<String, Integer> processorList;
+ private final Map<String, Integer> processorList;
- public StormComponentFactory() {
- processorList = new HashMap<>();
+ public StormComponentFactory() {
+ processorList = new HashMap<>();
+ }
+
+ @Override
+ public ProcessingItem createPi(Processor processor) {
+ return new StormProcessingItem(processor, this.getComponentName(processor.getClass()), 1);
+ }
+
+ @Override
+ public EntranceProcessingItem createEntrancePi(EntranceProcessor processor) {
+ return new StormEntranceProcessingItem(processor, this.getComponentName(processor.getClass()));
+ }
+
+ @Override
+ public Stream createStream(IProcessingItem sourcePi) {
+ StormTopologyNode stormCompatiblePi = (StormTopologyNode) sourcePi;
+ return stormCompatiblePi.createStream();
+ }
+
+ @Override
+ public Topology createTopology(String topoName) {
+ return new StormTopology(topoName);
+ }
+
+ private String getComponentName(Class<? extends Processor> clazz) {
+ StringBuilder componentName = new StringBuilder(clazz.getCanonicalName());
+ String key = componentName.toString();
+ Integer index;
+
+ if (!processorList.containsKey(key)) {
+ index = 1;
+ } else {
+ index = processorList.get(key) + 1;
}
- @Override
- public ProcessingItem createPi(Processor processor) {
- return new StormProcessingItem(processor, this.getComponentName(processor.getClass()), 1);
- }
+ processorList.put(key, index);
- @Override
- public EntranceProcessingItem createEntrancePi(EntranceProcessor processor) {
- return new StormEntranceProcessingItem(processor, this.getComponentName(processor.getClass()));
- }
+ componentName.append('_');
+ componentName.append(index);
- @Override
- public Stream createStream(IProcessingItem sourcePi) {
- StormTopologyNode stormCompatiblePi = (StormTopologyNode) sourcePi;
- return stormCompatiblePi.createStream();
- }
+ return componentName.toString();
+ }
- @Override
- public Topology createTopology(String topoName) {
- return new StormTopology(topoName);
- }
-
- private String getComponentName(Class<? extends Processor> clazz) {
- StringBuilder componentName = new StringBuilder(clazz.getCanonicalName());
- String key = componentName.toString();
- Integer index;
-
- if (!processorList.containsKey(key)) {
- index = 1;
- } else {
- index = processorList.get(key) + 1;
- }
-
- processorList.put(key, index);
-
- componentName.append('_');
- componentName.append(index);
-
- return componentName.toString();
- }
-
- @Override
- public ProcessingItem createPi(Processor processor, int parallelism) {
- return new StormProcessingItem(processor, this.getComponentName(processor.getClass()), parallelism);
- }
+ @Override
+ public ProcessingItem createPi(Processor processor, int parallelism) {
+ return new StormProcessingItem(processor, this.getComponentName(processor.getClass()), parallelism);
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormDoTask.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormDoTask.java
index fc0630a..758ae1e 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormDoTask.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormDoTask.java
@@ -31,87 +31,88 @@
import backtype.storm.utils.Utils;
/**
- * The main class that used by samoa script to execute SAMOA task.
+ * The main class that used by samoa script to execute SAMOA task.
*
* @author Arinto Murdopo
- *
+ *
*/
public class StormDoTask {
- private static final Logger logger = LoggerFactory.getLogger(StormDoTask.class);
- private static String localFlag = "local";
- private static String clusterFlag = "cluster";
-
- /**
- * The main method.
- *
- * @param args the arguments
- */
- public static void main(String[] args) {
+ private static final Logger logger = LoggerFactory.getLogger(StormDoTask.class);
+ private static String localFlag = "local";
+ private static String clusterFlag = "cluster";
- List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
-
- boolean isLocal = isLocal(tmpArgs);
- int numWorker = StormSamoaUtils.numWorkers(tmpArgs);
-
- args = tmpArgs.toArray(new String[0]);
-
- //convert the arguments into Storm topology
- StormTopology stormTopo = StormSamoaUtils.argsToTopology(args);
- String topologyName = stormTopo.getTopologyName();
-
- Config conf = new Config();
- conf.putAll(Utils.readStormConfig());
- conf.setDebug(false);
-
-
- if(isLocal){
- //local mode
- conf.setMaxTaskParallelism(numWorker);
-
- backtype.storm.LocalCluster cluster = new backtype.storm.LocalCluster();
- cluster.submitTopology(topologyName , conf, stormTopo.getStormBuilder().createTopology());
-
- backtype.storm.utils.Utils.sleep(600*1000);
-
- cluster.killTopology(topologyName);
- cluster.shutdown();
-
- }else{
- //cluster mode
- conf.setNumWorkers(numWorker);
- try {
- backtype.storm.StormSubmitter.submitTopology(topologyName, conf,
- stormTopo.getStormBuilder().createTopology());
- } catch (backtype.storm.generated.AlreadyAliveException ale) {
- ale.printStackTrace();
- } catch (backtype.storm.generated.InvalidTopologyException ite) {
- ite.printStackTrace();
- }
- }
- }
-
- private static boolean isLocal(List<String> tmpArgs){
- ExecutionMode executionMode = ExecutionMode.UNDETERMINED;
-
- int position = tmpArgs.size() - 1;
- String flag = tmpArgs.get(position);
- boolean isLocal = true;
-
- if(flag.equals(clusterFlag)){
- executionMode = ExecutionMode.CLUSTER;
- isLocal = false;
- }else if(flag.equals(localFlag)){
- executionMode = ExecutionMode.LOCAL;
- isLocal = true;
- }
-
- if(executionMode != ExecutionMode.UNDETERMINED){
- tmpArgs.remove(position);
- }
-
- return isLocal;
- }
-
- private enum ExecutionMode {LOCAL, CLUSTER, UNDETERMINED};
+ /**
+ * The main method.
+ *
+ * @param args
+ * the arguments
+ */
+ public static void main(String[] args) {
+
+ List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
+
+ boolean isLocal = isLocal(tmpArgs);
+ int numWorker = StormSamoaUtils.numWorkers(tmpArgs);
+
+ args = tmpArgs.toArray(new String[0]);
+
+ // convert the arguments into Storm topology
+ StormTopology stormTopo = StormSamoaUtils.argsToTopology(args);
+ String topologyName = stormTopo.getTopologyName();
+
+ Config conf = new Config();
+ conf.putAll(Utils.readStormConfig());
+ conf.setDebug(false);
+
+ if (isLocal) {
+ // local mode
+ conf.setMaxTaskParallelism(numWorker);
+
+ backtype.storm.LocalCluster cluster = new backtype.storm.LocalCluster();
+ cluster.submitTopology(topologyName, conf, stormTopo.getStormBuilder().createTopology());
+
+ backtype.storm.utils.Utils.sleep(600 * 1000);
+
+ cluster.killTopology(topologyName);
+ cluster.shutdown();
+
+ } else {
+ // cluster mode
+ conf.setNumWorkers(numWorker);
+ try {
+ backtype.storm.StormSubmitter.submitTopology(topologyName, conf,
+ stormTopo.getStormBuilder().createTopology());
+ } catch (backtype.storm.generated.AlreadyAliveException ale) {
+ ale.printStackTrace();
+ } catch (backtype.storm.generated.InvalidTopologyException ite) {
+ ite.printStackTrace();
+ }
+ }
+ }
+
+ private static boolean isLocal(List<String> tmpArgs) {
+ ExecutionMode executionMode = ExecutionMode.UNDETERMINED;
+
+ int position = tmpArgs.size() - 1;
+ String flag = tmpArgs.get(position);
+ boolean isLocal = true;
+
+ if (flag.equals(clusterFlag)) {
+ executionMode = ExecutionMode.CLUSTER;
+ isLocal = false;
+ } else if (flag.equals(localFlag)) {
+ executionMode = ExecutionMode.LOCAL;
+ isLocal = true;
+ }
+
+ if (executionMode != ExecutionMode.UNDETERMINED) {
+ tmpArgs.remove(position);
+ }
+
+ return isLocal;
+ }
+
+ private enum ExecutionMode {
+ LOCAL, CLUSTER, UNDETERMINED
+ };
}
-
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormEntranceProcessingItem.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormEntranceProcessingItem.java
index d4d80bf..832ee34 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormEntranceProcessingItem.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormEntranceProcessingItem.java
@@ -41,168 +41,172 @@
* EntranceProcessingItem implementation for Storm.
*/
class StormEntranceProcessingItem extends AbstractEntranceProcessingItem implements StormTopologyNode {
- private final StormEntranceSpout piSpout;
+ private final StormEntranceSpout piSpout;
- StormEntranceProcessingItem(EntranceProcessor processor) {
- this(processor, UUID.randomUUID().toString());
+ StormEntranceProcessingItem(EntranceProcessor processor) {
+ this(processor, UUID.randomUUID().toString());
+ }
+
+ StormEntranceProcessingItem(EntranceProcessor processor, String friendlyId) {
+ super(processor);
+ this.setName(friendlyId);
+ this.piSpout = new StormEntranceSpout(processor);
+ }
+
+ @Override
+ public EntranceProcessingItem setOutputStream(Stream stream) {
+ // piSpout.streams.add(stream);
+ piSpout.setOutputStream((StormStream) stream);
+ return this;
+ }
+
+ @Override
+ public Stream getOutputStream() {
+ return piSpout.getOutputStream();
+ }
+
+ @Override
+ public void addToTopology(StormTopology topology, int parallelismHint) {
+ topology.getStormBuilder().setSpout(this.getName(), piSpout, parallelismHint);
+ }
+
+ @Override
+ public StormStream createStream() {
+ return piSpout.createStream(this.getName());
+ }
+
+ @Override
+ public String getId() {
+ return this.getName();
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder(super.toString());
+ sb.insert(0, String.format("id: %s, ", this.getName()));
+ return sb.toString();
+ }
+
+ /**
+ * Resulting Spout of StormEntranceProcessingItem
+ */
+ final static class StormEntranceSpout extends BaseRichSpout {
+
+ private static final long serialVersionUID = -9066409791668954099L;
+
+ // private final Set<StormSpoutStream> streams;
+ private final EntranceProcessor entranceProcessor;
+ private StormStream outputStream;
+
+ // private transient SpoutStarter spoutStarter;
+ // private transient Executor spoutExecutors;
+ // private transient LinkedBlockingQueue<StormTupleInfo> tupleInfoQueue;
+
+ private SpoutOutputCollector collector;
+
+ StormEntranceSpout(EntranceProcessor processor) {
+ // this.streams = new HashSet<StormSpoutStream>();
+ this.entranceProcessor = processor;
}
- StormEntranceProcessingItem(EntranceProcessor processor, String friendlyId) {
- super(processor);
- this.setName(friendlyId);
- this.piSpout = new StormEntranceSpout(processor);
+ public StormStream getOutputStream() {
+ return outputStream;
+ }
+
+ public void setOutputStream(StormStream stream) {
+ this.outputStream = stream;
}
@Override
- public EntranceProcessingItem setOutputStream(Stream stream) {
- // piSpout.streams.add(stream);
- piSpout.setOutputStream((StormStream) stream);
- return this;
- }
-
- @Override
- public Stream getOutputStream() {
- return piSpout.getOutputStream();
- }
-
- @Override
- public void addToTopology(StormTopology topology, int parallelismHint) {
- topology.getStormBuilder().setSpout(this.getName(), piSpout, parallelismHint);
+ public void open(@SuppressWarnings("rawtypes") Map conf, TopologyContext context, SpoutOutputCollector collector) {
+ this.collector = collector;
+ // this.tupleInfoQueue = new LinkedBlockingQueue<StormTupleInfo>();
+
+ // Processor and this class share the same instance of stream
+ // for (StormSpoutStream stream : streams) {
+ // stream.setSpout(this);
+ // }
+ // outputStream.setSpout(this);
+
+ this.entranceProcessor.onCreate(context.getThisTaskId());
+ // this.spoutStarter = new SpoutStarter(this.starter);
+
+ // this.spoutExecutors = Executors.newSingleThreadExecutor();
+ // this.spoutExecutors.execute(spoutStarter);
}
@Override
- public StormStream createStream() {
- return piSpout.createStream(this.getName());
+ public void nextTuple() {
+ if (entranceProcessor.hasNext()) {
+ Values value = newValues(entranceProcessor.nextEvent());
+ collector.emit(outputStream.getOutputId(), value);
+ } else
+ Utils.sleep(1000);
+ // StormTupleInfo tupleInfo = tupleInfoQueue.poll(50,
+ // TimeUnit.MILLISECONDS);
+ // if (tupleInfo != null) {
+ // Values value = newValues(tupleInfo.getContentEvent());
+ // collector.emit(tupleInfo.getStormStream().getOutputId(), value);
+ // }
}
@Override
- public String getId() {
- return this.getName();
+ public void declareOutputFields(OutputFieldsDeclarer declarer) {
+ // for (StormStream stream : streams) {
+ // declarer.declareStream(stream.getOutputId(), new
+ // Fields(StormSamoaUtils.CONTENT_EVENT_FIELD,
+ // StormSamoaUtils.KEY_FIELD));
+ // }
+ declarer.declareStream(outputStream.getOutputId(), new Fields(StormSamoaUtils.CONTENT_EVENT_FIELD,
+ StormSamoaUtils.KEY_FIELD));
}
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder(super.toString());
- sb.insert(0, String.format("id: %s, ", this.getName()));
- return sb.toString();
+ StormStream createStream(String piId) {
+ // StormSpoutStream stream = new StormSpoutStream(piId);
+ StormStream stream = new StormBoltStream(piId);
+ // streams.add(stream);
+ return stream;
}
- /**
- * Resulting Spout of StormEntranceProcessingItem
- */
- final static class StormEntranceSpout extends BaseRichSpout {
+ // void put(StormSpoutStream stream, ContentEvent contentEvent) {
+ // tupleInfoQueue.add(new StormTupleInfo(stream, contentEvent));
+ // }
- private static final long serialVersionUID = -9066409791668954099L;
-
- // private final Set<StormSpoutStream> streams;
- private final EntranceProcessor entranceProcessor;
- private StormStream outputStream;
-
- // private transient SpoutStarter spoutStarter;
- // private transient Executor spoutExecutors;
- // private transient LinkedBlockingQueue<StormTupleInfo> tupleInfoQueue;
-
- private SpoutOutputCollector collector;
-
- StormEntranceSpout(EntranceProcessor processor) {
- // this.streams = new HashSet<StormSpoutStream>();
- this.entranceProcessor = processor;
- }
-
- public StormStream getOutputStream() {
- return outputStream;
- }
-
- public void setOutputStream(StormStream stream) {
- this.outputStream = stream;
- }
-
- @Override
- public void open(@SuppressWarnings("rawtypes") Map conf, TopologyContext context, SpoutOutputCollector collector) {
- this.collector = collector;
- // this.tupleInfoQueue = new LinkedBlockingQueue<StormTupleInfo>();
-
- // Processor and this class share the same instance of stream
- // for (StormSpoutStream stream : streams) {
- // stream.setSpout(this);
- // }
- // outputStream.setSpout(this);
-
- this.entranceProcessor.onCreate(context.getThisTaskId());
- // this.spoutStarter = new SpoutStarter(this.starter);
-
- // this.spoutExecutors = Executors.newSingleThreadExecutor();
- // this.spoutExecutors.execute(spoutStarter);
- }
-
- @Override
- public void nextTuple() {
- if (entranceProcessor.hasNext()) {
- Values value = newValues(entranceProcessor.nextEvent());
- collector.emit(outputStream.getOutputId(), value);
- } else
- Utils.sleep(1000);
- // StormTupleInfo tupleInfo = tupleInfoQueue.poll(50, TimeUnit.MILLISECONDS);
- // if (tupleInfo != null) {
- // Values value = newValues(tupleInfo.getContentEvent());
- // collector.emit(tupleInfo.getStormStream().getOutputId(), value);
- // }
- }
-
- @Override
- public void declareOutputFields(OutputFieldsDeclarer declarer) {
- // for (StormStream stream : streams) {
- // declarer.declareStream(stream.getOutputId(), new Fields(StormSamoaUtils.CONTENT_EVENT_FIELD, StormSamoaUtils.KEY_FIELD));
- // }
- declarer.declareStream(outputStream.getOutputId(), new Fields(StormSamoaUtils.CONTENT_EVENT_FIELD, StormSamoaUtils.KEY_FIELD));
- }
-
- StormStream createStream(String piId) {
- // StormSpoutStream stream = new StormSpoutStream(piId);
- StormStream stream = new StormBoltStream(piId);
- // streams.add(stream);
- return stream;
- }
-
- // void put(StormSpoutStream stream, ContentEvent contentEvent) {
- // tupleInfoQueue.add(new StormTupleInfo(stream, contentEvent));
- // }
-
- private Values newValues(ContentEvent contentEvent) {
- return new Values(contentEvent, contentEvent.getKey());
- }
-
- // private final static class StormTupleInfo {
- //
- // private final StormStream stream;
- // private final ContentEvent event;
- //
- // StormTupleInfo(StormStream stream, ContentEvent event) {
- // this.stream = stream;
- // this.event = event;
- // }
- //
- // public StormStream getStormStream() {
- // return this.stream;
- // }
- //
- // public ContentEvent getContentEvent() {
- // return this.event;
- // }
- // }
-
- // private final static class SpoutStarter implements Runnable {
- //
- // private final TopologyStarter topoStarter;
- //
- // SpoutStarter(TopologyStarter topoStarter) {
- // this.topoStarter = topoStarter;
- // }
- //
- // @Override
- // public void run() {
- // this.topoStarter.start();
- // }
- // }
+ private Values newValues(ContentEvent contentEvent) {
+ return new Values(contentEvent, contentEvent.getKey());
}
+
+ // private final static class StormTupleInfo {
+ //
+ // private final StormStream stream;
+ // private final ContentEvent event;
+ //
+ // StormTupleInfo(StormStream stream, ContentEvent event) {
+ // this.stream = stream;
+ // this.event = event;
+ // }
+ //
+ // public StormStream getStormStream() {
+ // return this.stream;
+ // }
+ //
+ // public ContentEvent getContentEvent() {
+ // return this.event;
+ // }
+ // }
+
+ // private final static class SpoutStarter implements Runnable {
+ //
+ // private final TopologyStarter topoStarter;
+ //
+ // SpoutStarter(TopologyStarter topoStarter) {
+ // this.topoStarter = topoStarter;
+ // }
+ //
+ // @Override
+ // public void run() {
+ // this.topoStarter.start();
+ // }
+ // }
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormJarSubmitter.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormJarSubmitter.java
index 5f86855..6594aa7 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormJarSubmitter.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormJarSubmitter.java
@@ -34,42 +34,42 @@
* Utility class to submit samoa-storm jar to a Storm cluster.
*
* @author Arinto Murdopo
- *
+ *
*/
public class StormJarSubmitter {
-
- public final static String UPLOADED_JAR_LOCATION_KEY = "UploadedJarLocation";
- /**
- * @param args
- * @throws IOException
- */
- public static void main(String[] args) throws IOException {
-
- Config config = new Config();
- config.putAll(Utils.readCommandLineOpts());
- config.putAll(Utils.readStormConfig());
+ public final static String UPLOADED_JAR_LOCATION_KEY = "UploadedJarLocation";
- String nimbusHost = (String) config.get(Config.NIMBUS_HOST);
- int nimbusThriftPort = Utils.getInt(config
- .get(Config.NIMBUS_THRIFT_PORT));
+ /**
+ * @param args
+ * @throws IOException
+ */
+ public static void main(String[] args) throws IOException {
- System.out.println("Nimbus host " + nimbusHost);
- System.out.println("Nimbus thrift port " + nimbusThriftPort);
+ Config config = new Config();
+ config.putAll(Utils.readCommandLineOpts());
+ config.putAll(Utils.readStormConfig());
- System.out.println("uploading jar from " + args[0]);
- String uploadedJarLocation = StormSubmitter.submitJar(config, args[0]);
-
- System.out.println("Uploaded jar file location: ");
- System.out.println(uploadedJarLocation);
-
- Properties props = StormSamoaUtils.getProperties();
- props.setProperty(StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY, uploadedJarLocation);
-
- File f = new File("src/main/resources/samoa-storm-cluster.properties");
- f.createNewFile();
-
- OutputStream out = new FileOutputStream(f);
- props.store(out, "properties file to store uploaded jar location from StormJarSubmitter");
- }
+ String nimbusHost = (String) config.get(Config.NIMBUS_HOST);
+ int nimbusThriftPort = Utils.getInt(config
+ .get(Config.NIMBUS_THRIFT_PORT));
+
+ System.out.println("Nimbus host " + nimbusHost);
+ System.out.println("Nimbus thrift port " + nimbusThriftPort);
+
+ System.out.println("uploading jar from " + args[0]);
+ String uploadedJarLocation = StormSubmitter.submitJar(config, args[0]);
+
+ System.out.println("Uploaded jar file location: ");
+ System.out.println(uploadedJarLocation);
+
+ Properties props = StormSamoaUtils.getProperties();
+ props.setProperty(StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY, uploadedJarLocation);
+
+ File f = new File("src/main/resources/samoa-storm-cluster.properties");
+ f.createNewFile();
+
+ OutputStream out = new FileOutputStream(f);
+ props.store(out, "properties file to store uploaded jar location from StormJarSubmitter");
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItem.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItem.java
index 73879f6..1a9064c 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItem.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItem.java
@@ -44,127 +44,126 @@
/**
* ProcessingItem implementation for Storm.
+ *
* @author Arinto Murdopo
- *
+ *
*/
class StormProcessingItem extends AbstractProcessingItem implements StormTopologyNode {
- private final ProcessingItemBolt piBolt;
- private BoltDeclarer piBoltDeclarer;
-
- //TODO: should we put parallelism hint here?
- //imo, parallelism hint only declared when we add this PI in the topology
- //open for dicussion :p
-
- StormProcessingItem(Processor processor, int parallelismHint){
- this(processor, UUID.randomUUID().toString(), parallelismHint);
- }
-
- StormProcessingItem(Processor processor, String friendlyId, int parallelismHint){
- super(processor, parallelismHint);
- this.piBolt = new ProcessingItemBolt(processor);
- this.setName(friendlyId);
- }
+ private final ProcessingItemBolt piBolt;
+ private BoltDeclarer piBoltDeclarer;
- @Override
- protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
- StormStream stormInputStream = (StormStream) inputStream;
- InputStreamId inputId = stormInputStream.getInputId();
-
- switch(scheme) {
- case SHUFFLE:
- piBoltDeclarer.shuffleGrouping(inputId.getComponentId(),inputId.getStreamId());
- break;
- case GROUP_BY_KEY:
- piBoltDeclarer.fieldsGrouping(
- inputId.getComponentId(),
- inputId.getStreamId(),
- new Fields(StormSamoaUtils.KEY_FIELD));
- break;
- case BROADCAST:
- piBoltDeclarer.allGrouping(
- inputId.getComponentId(),
- inputId.getStreamId());
- break;
- }
- return this;
- }
-
- @Override
- public void addToTopology(StormTopology topology, int parallelismHint) {
- if(piBoltDeclarer != null){
- //throw exception that one PI only belong to one topology
- }else{
- TopologyBuilder stormBuilder = topology.getStormBuilder();
- this.piBoltDeclarer = stormBuilder.setBolt(this.getName(),
- this.piBolt, parallelismHint);
- }
- }
+ // TODO: should we put parallelism hint here?
+ // imo, parallelism hint only declared when we add this PI in the topology
+ // open for discussion :p
- @Override
- public StormStream createStream() {
- return piBolt.createStream(this.getName());
- }
+ StormProcessingItem(Processor processor, int parallelismHint) {
+ this(processor, UUID.randomUUID().toString(), parallelismHint);
+ }
- @Override
- public String getId() {
- return this.getName();
- }
-
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder(super.toString());
- sb.insert(0, String.format("id: %s, ", this.getName()));
- return sb.toString();
- }
-
- private final static class ProcessingItemBolt extends BaseRichBolt{
-
- private static final long serialVersionUID = -6637673741263199198L;
-
- private final Set<StormBoltStream> streams;
- private final Processor processor;
-
- private OutputCollector collector;
-
- ProcessingItemBolt(Processor processor){
- this.streams = new HashSet<StormBoltStream>();
- this.processor = processor;
- }
-
- @Override
- public void prepare(@SuppressWarnings("rawtypes") Map stormConf, TopologyContext context,
- OutputCollector collector) {
- this.collector = collector;
- //Processor and this class share the same instance of stream
- for(StormBoltStream stream: streams){
- stream.setCollector(this.collector);
- }
-
- this.processor.onCreate(context.getThisTaskId());
- }
+ StormProcessingItem(Processor processor, String friendlyId, int parallelismHint) {
+ super(processor, parallelismHint);
+ this.piBolt = new ProcessingItemBolt(processor);
+ this.setName(friendlyId);
+ }
- @Override
- public void execute(Tuple input) {
- Object sentObject = input.getValue(0);
- ContentEvent sentEvent = (ContentEvent)sentObject;
- processor.process(sentEvent);
- }
+ @Override
+ protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
+ StormStream stormInputStream = (StormStream) inputStream;
+ InputStreamId inputId = stormInputStream.getInputId();
- @Override
- public void declareOutputFields(OutputFieldsDeclarer declarer) {
- for(StormStream stream: streams){
- declarer.declareStream(stream.getOutputId(),
- new Fields(StormSamoaUtils.CONTENT_EVENT_FIELD,
- StormSamoaUtils.KEY_FIELD));
- }
- }
-
- StormStream createStream(String piId){
- StormBoltStream stream = new StormBoltStream(piId);
- streams.add(stream);
- return stream;
- }
- }
+ switch (scheme) {
+ case SHUFFLE:
+ piBoltDeclarer.shuffleGrouping(inputId.getComponentId(), inputId.getStreamId());
+ break;
+ case GROUP_BY_KEY:
+ piBoltDeclarer.fieldsGrouping(
+ inputId.getComponentId(),
+ inputId.getStreamId(),
+ new Fields(StormSamoaUtils.KEY_FIELD));
+ break;
+ case BROADCAST:
+ piBoltDeclarer.allGrouping(
+ inputId.getComponentId(),
+ inputId.getStreamId());
+ break;
+ }
+ return this;
+ }
+
+ @Override
+ public void addToTopology(StormTopology topology, int parallelismHint) {
+ if (piBoltDeclarer != null) {
+ // throw exception that one PI only belongs to one topology
+ } else {
+ TopologyBuilder stormBuilder = topology.getStormBuilder();
+ this.piBoltDeclarer = stormBuilder.setBolt(this.getName(),
+ this.piBolt, parallelismHint);
+ }
+ }
+
+ @Override
+ public StormStream createStream() {
+ return piBolt.createStream(this.getName());
+ }
+
+ @Override
+ public String getId() {
+ return this.getName();
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder(super.toString());
+ sb.insert(0, String.format("id: %s, ", this.getName()));
+ return sb.toString();
+ }
+
+ private final static class ProcessingItemBolt extends BaseRichBolt {
+
+ private static final long serialVersionUID = -6637673741263199198L;
+
+ private final Set<StormBoltStream> streams;
+ private final Processor processor;
+
+ private OutputCollector collector;
+
+ ProcessingItemBolt(Processor processor) {
+ this.streams = new HashSet<StormBoltStream>();
+ this.processor = processor;
+ }
+
+ @Override
+ public void prepare(@SuppressWarnings("rawtypes") Map stormConf, TopologyContext context,
+ OutputCollector collector) {
+ this.collector = collector;
+ // Processor and this class share the same instance of stream
+ for (StormBoltStream stream : streams) {
+ stream.setCollector(this.collector);
+ }
+
+ this.processor.onCreate(context.getThisTaskId());
+ }
+
+ @Override
+ public void execute(Tuple input) {
+ Object sentObject = input.getValue(0);
+ ContentEvent sentEvent = (ContentEvent) sentObject;
+ processor.process(sentEvent);
+ }
+
+ @Override
+ public void declareOutputFields(OutputFieldsDeclarer declarer) {
+ for (StormStream stream : streams) {
+ declarer.declareStream(stream.getOutputId(),
+ new Fields(StormSamoaUtils.CONTENT_EVENT_FIELD,
+ StormSamoaUtils.KEY_FIELD));
+ }
+ }
+
+ StormStream createStream(String piId) {
+ StormBoltStream stream = new StormBoltStream(piId);
+ streams.add(stream);
+ return stream;
+ }
+ }
}
-
-
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSamoaUtils.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSamoaUtils.java
index 7c4769e..d978a8f 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSamoaUtils.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSamoaUtils.java
@@ -34,75 +34,82 @@
import com.yahoo.labs.samoa.tasks.Task;
/**
- * Utility class for samoa-storm project. It is used by StormDoTask to process its arguments.
+ * Utility class for samoa-storm project. It is used by StormDoTask to process
+ * its arguments.
+ *
* @author Arinto Murdopo
- *
+ *
*/
public class StormSamoaUtils {
-
- private static final Logger logger = LoggerFactory.getLogger(StormSamoaUtils.class);
- static final String KEY_FIELD = "key";
- static final String CONTENT_EVENT_FIELD = "content_event";
-
- static Properties getProperties() throws IOException{
- Properties props = new Properties();
- InputStream is;
-
- File f = new File("src/main/resources/samoa-storm-cluster.properties"); // FIXME it does not exist anymore
- is = new FileInputStream(f);
-
- try {
- props.load(is);
- } catch (IOException e1) {
- System.out.println("Fail to load property file");
- return null;
- } finally{
- is.close();
- }
-
- return props;
- }
-
- public static StormTopology argsToTopology(String[] args){
- StringBuilder cliString = new StringBuilder();
- for (String arg : args) {
- cliString.append(" ").append(arg);
- }
- logger.debug("Command line string = {}", cliString.toString());
+ private static final Logger logger = LoggerFactory.getLogger(StormSamoaUtils.class);
- Task task = getTask(cliString.toString());
-
- //TODO: remove setFactory method with DynamicBinding
- task.setFactory(new StormComponentFactory());
- task.init();
+ static final String KEY_FIELD = "key";
+ static final String CONTENT_EVENT_FIELD = "content_event";
- return (StormTopology)task.getTopology();
- }
-
- public static int numWorkers(List<String> tmpArgs){
- int position = tmpArgs.size() - 1;
- int numWorkers;
-
- try {
- numWorkers = Integer.parseInt(tmpArgs.get(position));
- tmpArgs.remove(position);
- } catch (NumberFormatException e) {
- numWorkers = 4;
- }
-
- return numWorkers;
- }
+ static Properties getProperties() throws IOException {
+ Properties props = new Properties();
+ InputStream is;
- public static Task getTask(String cliString) {
- Task task = null;
- try {
- logger.debug("Providing task [{}]", cliString);
- task = ClassOption.cliStringToObject(cliString, Task.class, null);
- } catch (Exception e) {
- logger.warn("Fail in initializing the task!");
- e.printStackTrace();
- }
- return task;
+ File f = new File("src/main/resources/samoa-storm-cluster.properties"); // FIXME
+ // it
+ // does
+ // not
+ // exist
+ // anymore
+ is = new FileInputStream(f);
+
+ try {
+ props.load(is);
+ } catch (IOException e1) {
+ System.out.println("Fail to load property file");
+ return null;
+ } finally {
+ is.close();
}
+
+ return props;
+ }
+
+ public static StormTopology argsToTopology(String[] args) {
+ StringBuilder cliString = new StringBuilder();
+ for (String arg : args) {
+ cliString.append(" ").append(arg);
+ }
+ logger.debug("Command line string = {}", cliString.toString());
+
+ Task task = getTask(cliString.toString());
+
+ // TODO: remove setFactory method with DynamicBinding
+ task.setFactory(new StormComponentFactory());
+ task.init();
+
+ return (StormTopology) task.getTopology();
+ }
+
+ public static int numWorkers(List<String> tmpArgs) {
+ int position = tmpArgs.size() - 1;
+ int numWorkers;
+
+ try {
+ numWorkers = Integer.parseInt(tmpArgs.get(position));
+ tmpArgs.remove(position);
+ } catch (NumberFormatException e) {
+ numWorkers = 4;
+ }
+
+ return numWorkers;
+ }
+
+ public static Task getTask(String cliString) {
+ Task task = null;
+ try {
+ logger.debug("Providing task [{}]", cliString);
+ task = ClassOption.cliStringToObject(cliString, Task.class, null);
+ } catch (Exception e) {
+ logger.warn("Fail in initializing the task!");
+ e.printStackTrace();
+ }
+ return task;
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSpoutStream.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSpoutStream.java
index d066e42..06f5bb2 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSpoutStream.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormSpoutStream.java
@@ -62,4 +62,4 @@
// return null;
// }
//
-//}
+// }
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormStream.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormStream.java
index f67ab19..ed39a50 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormStream.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormStream.java
@@ -27,59 +27,60 @@
/**
* Abstract class to implement Storm Stream
+ *
* @author Arinto Murdopo
- *
+ *
*/
abstract class StormStream implements Stream, java.io.Serializable {
-
- /**
+
+ /**
*
*/
- private static final long serialVersionUID = 281835563756514852L;
- protected final String outputStreamId;
- protected final InputStreamId inputStreamId;
-
- public StormStream(String stormComponentId){
- this.outputStreamId = UUID.randomUUID().toString();
- this.inputStreamId = new InputStreamId(stormComponentId, this.outputStreamId);
- }
-
- @Override
- public abstract void put(ContentEvent contentEvent);
-
- String getOutputId(){
- return this.outputStreamId;
- }
-
- InputStreamId getInputId(){
- return this.inputStreamId;
- }
-
- final static class InputStreamId implements java.io.Serializable{
-
- /**
+ private static final long serialVersionUID = 281835563756514852L;
+ protected final String outputStreamId;
+ protected final InputStreamId inputStreamId;
+
+ public StormStream(String stormComponentId) {
+ this.outputStreamId = UUID.randomUUID().toString();
+ this.inputStreamId = new InputStreamId(stormComponentId, this.outputStreamId);
+ }
+
+ @Override
+ public abstract void put(ContentEvent contentEvent);
+
+ String getOutputId() {
+ return this.outputStreamId;
+ }
+
+ InputStreamId getInputId() {
+ return this.inputStreamId;
+ }
+
+ final static class InputStreamId implements java.io.Serializable {
+
+ /**
*
*/
- private static final long serialVersionUID = -7457995634133691295L;
- private final String componentId;
- private final String streamId;
-
- InputStreamId(String componentId, String streamId){
- this.componentId = componentId;
- this.streamId = streamId;
- }
-
- String getComponentId(){
- return componentId;
- }
-
- String getStreamId(){
- return streamId;
- }
- }
-
- @Override
- public void setBatchSize(int batchSize) {
- // Ignore batch size
- }
+ private static final long serialVersionUID = -7457995634133691295L;
+ private final String componentId;
+ private final String streamId;
+
+ InputStreamId(String componentId, String streamId) {
+ this.componentId = componentId;
+ this.streamId = streamId;
+ }
+
+ String getComponentId() {
+ return componentId;
+ }
+
+ String getStreamId() {
+ return streamId;
+ }
+ }
+
+ @Override
+ public void setBatchSize(int batchSize) {
+ // Ignore batch size
+ }
}
\ No newline at end of file
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopology.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopology.java
index 7a49d8b..20995d5 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopology.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopology.java
@@ -27,26 +27,27 @@
/**
* Adaptation of SAMOA topology in samoa-storm
+ *
* @author Arinto Murdopo
- *
+ *
*/
public class StormTopology extends AbstractTopology {
-
- private TopologyBuilder builder;
-
- public StormTopology(String topologyName){
- super(topologyName);
- this.builder = new TopologyBuilder();
- }
-
- @Override
- public void addProcessingItem(IProcessingItem procItem, int parallelismHint){
- StormTopologyNode stormNode = (StormTopologyNode) procItem;
- stormNode.addToTopology(this, parallelismHint);
- super.addProcessingItem(procItem, parallelismHint);
- }
-
- public TopologyBuilder getStormBuilder(){
- return builder;
- }
+
+ private TopologyBuilder builder;
+
+ public StormTopology(String topologyName) {
+ super(topologyName);
+ this.builder = new TopologyBuilder();
+ }
+
+ @Override
+ public void addProcessingItem(IProcessingItem procItem, int parallelismHint) {
+ StormTopologyNode stormNode = (StormTopologyNode) procItem;
+ stormNode.addToTopology(this, parallelismHint);
+ super.addProcessingItem(procItem, parallelismHint);
+ }
+
+ public TopologyBuilder getStormBuilder() {
+ return builder;
+ }
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologyNode.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologyNode.java
index 07fccbf..8be3a1b 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologyNode.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologyNode.java
@@ -22,13 +22,16 @@
/**
* Interface to represent a node in samoa-storm topology.
+ *
* @author Arinto Murdopo
- *
+ *
*/
interface StormTopologyNode {
- void addToTopology(StormTopology topology, int parallelismHint);
- StormStream createStream();
- String getId();
-
+ void addToTopology(StormTopology topology, int parallelismHint);
+
+ StormStream createStream();
+
+ String getId();
+
}
diff --git a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologySubmitter.java b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologySubmitter.java
index 1e1b048..a4f1f51 100644
--- a/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologySubmitter.java
+++ b/samoa-storm/src/main/java/com/yahoo/labs/samoa/topology/impl/StormTopologySubmitter.java
@@ -41,93 +41,95 @@
import backtype.storm.utils.Utils;
/**
- * Helper class to submit SAMOA task into Storm without the need of submitting the jar file.
- * The jar file must be submitted first using StormJarSubmitter class.
+ * Helper class to submit SAMOA task into Storm without the need of submitting
+ * the jar file. The jar file must be submitted first using StormJarSubmitter
+ * class.
+ *
* @author Arinto Murdopo
- *
+ *
*/
public class StormTopologySubmitter {
-
- public static String YJP_OPTIONS_KEY="YjpOptions";
-
- private static Logger logger = LoggerFactory.getLogger(StormTopologySubmitter.class);
-
- public static void main(String[] args) throws IOException{
- Properties props = StormSamoaUtils.getProperties();
-
- String uploadedJarLocation = props.getProperty(StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY);
- if(uploadedJarLocation == null){
- logger.error("Invalid properties file. It must have key {}",
- StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY);
- return;
- }
-
- List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
- int numWorkers = StormSamoaUtils.numWorkers(tmpArgs);
-
- args = tmpArgs.toArray(new String[0]);
- StormTopology stormTopo = StormSamoaUtils.argsToTopology(args);
- Config conf = new Config();
- conf.putAll(Utils.readStormConfig());
- conf.putAll(Utils.readCommandLineOpts());
- conf.setDebug(false);
- conf.setNumWorkers(numWorkers);
-
- String profilerOption =
- props.getProperty(StormTopologySubmitter.YJP_OPTIONS_KEY);
- if(profilerOption != null){
- String topoWorkerChildOpts = (String) conf.get(Config.TOPOLOGY_WORKER_CHILDOPTS);
- StringBuilder optionBuilder = new StringBuilder();
- if(topoWorkerChildOpts != null){
- optionBuilder.append(topoWorkerChildOpts);
- optionBuilder.append(' ');
- }
- optionBuilder.append(profilerOption);
- conf.put(Config.TOPOLOGY_WORKER_CHILDOPTS, optionBuilder.toString());
- }
+ public static String YJP_OPTIONS_KEY = "YjpOptions";
- Map<String, Object> myConfigMap = new HashMap<String, Object>(conf);
- StringWriter out = new StringWriter();
+ private static Logger logger = LoggerFactory.getLogger(StormTopologySubmitter.class);
- try {
- JSONValue.writeJSONString(myConfigMap, out);
- } catch (IOException e) {
- System.out.println("Error in writing JSONString");
- e.printStackTrace();
- return;
- }
-
- Config config = new Config();
- config.putAll(Utils.readStormConfig());
-
- String nimbusHost = (String) config.get(Config.NIMBUS_HOST);
-
- NimbusClient nc = new NimbusClient(nimbusHost);
- String topologyName = stormTopo.getTopologyName();
- try {
- System.out.println("Submitting topology with name: "
- + topologyName);
- nc.getClient().submitTopology(topologyName, uploadedJarLocation,
- out.toString(), stormTopo.getStormBuilder().createTopology());
- System.out.println(topologyName + " is successfully submitted");
+ public static void main(String[] args) throws IOException {
+ Properties props = StormSamoaUtils.getProperties();
- } catch (AlreadyAliveException aae) {
- System.out.println("Fail to submit " + topologyName
- + "\nError message: " + aae.get_msg());
- } catch (InvalidTopologyException ite) {
- System.out.println("Invalid topology for " + topologyName);
- ite.printStackTrace();
- } catch (TException te) {
- System.out.println("Texception for " + topologyName);
- te.printStackTrace();
- }
- }
-
- private static String uploadedJarLocation(List<String> tmpArgs){
- int position = tmpArgs.size() - 1;
- String uploadedJarLocation = tmpArgs.get(position);
- tmpArgs.remove(position);
- return uploadedJarLocation;
- }
+ String uploadedJarLocation = props.getProperty(StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY);
+ if (uploadedJarLocation == null) {
+ logger.error("Invalid properties file. It must have key {}",
+ StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY);
+ return;
+ }
+
+ List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
+ int numWorkers = StormSamoaUtils.numWorkers(tmpArgs);
+
+ args = tmpArgs.toArray(new String[0]);
+ StormTopology stormTopo = StormSamoaUtils.argsToTopology(args);
+
+ Config conf = new Config();
+ conf.putAll(Utils.readStormConfig());
+ conf.putAll(Utils.readCommandLineOpts());
+ conf.setDebug(false);
+ conf.setNumWorkers(numWorkers);
+
+ String profilerOption =
+ props.getProperty(StormTopologySubmitter.YJP_OPTIONS_KEY);
+ if (profilerOption != null) {
+ String topoWorkerChildOpts = (String) conf.get(Config.TOPOLOGY_WORKER_CHILDOPTS);
+ StringBuilder optionBuilder = new StringBuilder();
+ if (topoWorkerChildOpts != null) {
+ optionBuilder.append(topoWorkerChildOpts);
+ optionBuilder.append(' ');
+ }
+ optionBuilder.append(profilerOption);
+ conf.put(Config.TOPOLOGY_WORKER_CHILDOPTS, optionBuilder.toString());
+ }
+
+ Map<String, Object> myConfigMap = new HashMap<String, Object>(conf);
+ StringWriter out = new StringWriter();
+
+ try {
+ JSONValue.writeJSONString(myConfigMap, out);
+ } catch (IOException e) {
+ System.out.println("Error in writing JSONString");
+ e.printStackTrace();
+ return;
+ }
+
+ Config config = new Config();
+ config.putAll(Utils.readStormConfig());
+
+ String nimbusHost = (String) config.get(Config.NIMBUS_HOST);
+
+ NimbusClient nc = new NimbusClient(nimbusHost);
+ String topologyName = stormTopo.getTopologyName();
+ try {
+ System.out.println("Submitting topology with name: "
+ + topologyName);
+ nc.getClient().submitTopology(topologyName, uploadedJarLocation,
+ out.toString(), stormTopo.getStormBuilder().createTopology());
+ System.out.println(topologyName + " is successfully submitted");
+
+ } catch (AlreadyAliveException aae) {
+ System.out.println("Fail to submit " + topologyName
+ + "\nError message: " + aae.get_msg());
+ } catch (InvalidTopologyException ite) {
+ System.out.println("Invalid topology for " + topologyName);
+ ite.printStackTrace();
+ } catch (TException te) {
+ System.out.println("Texception for " + topologyName);
+ te.printStackTrace();
+ }
+ }
+
+ private static String uploadedJarLocation(List<String> tmpArgs) {
+ int position = tmpArgs.size() - 1;
+ String uploadedJarLocation = tmpArgs.get(position);
+ tmpArgs.remove(position);
+ return uploadedJarLocation;
+ }
}
diff --git a/samoa-storm/src/test/java/com/yahoo/labs/samoa/AlgosTest.java b/samoa-storm/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
index 15b80b5..9f6089c 100644
--- a/samoa-storm/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
+++ b/samoa-storm/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
@@ -24,45 +24,43 @@
public class AlgosTest {
+ @Test(timeout = 60000)
+ public void testVHTWithStorm() throws Exception {
- @Test(timeout = 60000)
- public void testVHTWithStorm() throws Exception {
+ TestParams vhtConfig = new TestParams.Builder()
+ .inputInstances(200_000)
+ .samplingSize(20_000)
+ .evaluationInstances(200_000)
+ .classifiedInstances(200_000)
+ .classificationsCorrect(55f)
+ .kappaStat(0f)
+ .kappaTempStat(0f)
+ .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE)
+ .resultFilePollTimeout(30)
+ .prePollWait(15)
+ .taskClassName(LocalStormDoTask.class.getName())
+ .build();
+ TestUtils.test(vhtConfig);
- TestParams vhtConfig = new TestParams.Builder()
- .inputInstances(200_000)
- .samplingSize(20_000)
- .evaluationInstances(200_000)
- .classifiedInstances(200_000)
- .classificationsCorrect(55f)
- .kappaStat(0f)
- .kappaTempStat(0f)
- .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE)
- .resultFilePollTimeout(30)
- .prePollWait(15)
- .taskClassName(LocalStormDoTask.class.getName())
- .build();
- TestUtils.test(vhtConfig);
+ }
- }
+ @Test(timeout = 120000)
+ public void testBaggingWithStorm() throws Exception {
+ TestParams baggingConfig = new TestParams.Builder()
+ .inputInstances(200_000)
+ .samplingSize(20_000)
+ .evaluationInstances(180_000)
+ .classifiedInstances(190_000)
+ .classificationsCorrect(60f)
+ .kappaStat(0f)
+ .kappaTempStat(0f)
+ .cliStringTemplate(TestParams.Templates.PREQEVAL_BAGGING_RANDOMTREE)
+ .resultFilePollTimeout(40)
+ .prePollWait(20)
+ .taskClassName(LocalStormDoTask.class.getName())
+ .build();
+ TestUtils.test(baggingConfig);
- @Test(timeout = 120000)
- public void testBaggingWithStorm() throws Exception {
- TestParams baggingConfig = new TestParams.Builder()
- .inputInstances(200_000)
- .samplingSize(20_000)
- .evaluationInstances(180_000)
- .classifiedInstances(190_000)
- .classificationsCorrect(60f)
- .kappaStat(0f)
- .kappaTempStat(0f)
- .cliStringTemplate(TestParams.Templates.PREQEVAL_BAGGING_RANDOMTREE)
- .resultFilePollTimeout(40)
- .prePollWait(20)
- .taskClassName(LocalStormDoTask.class.getName())
- .build();
- TestUtils.test(baggingConfig);
-
- }
-
+ }
}
diff --git a/samoa-storm/src/test/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItemTest.java b/samoa-storm/src/test/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItemTest.java
index ec8929a..d233ca6 100644
--- a/samoa-storm/src/test/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItemTest.java
+++ b/samoa-storm/src/test/java/com/yahoo/labs/samoa/topology/impl/StormProcessingItemTest.java
@@ -38,41 +38,46 @@
import com.yahoo.labs.samoa.core.Processor;
public class StormProcessingItemTest {
- private static final int PARRALLELISM_HINT_2 = 2;
- private static final int PARRALLELISM_HINT_4 = 4;
- private static final String ID = "id";
- @Tested private StormProcessingItem pi;
- @Mocked private Processor processor;
- @Mocked private StormTopology topology;
- @Mocked private TopologyBuilder stormBuilder = new TopologyBuilder();
+ private static final int PARRALLELISM_HINT_2 = 2;
+ private static final int PARRALLELISM_HINT_4 = 4;
+ private static final String ID = "id";
+ @Tested
+ private StormProcessingItem pi;
+ @Mocked
+ private Processor processor;
+ @Mocked
+ private StormTopology topology;
+ @Mocked
+ private TopologyBuilder stormBuilder = new TopologyBuilder();
- @Before
- public void setUp() {
- pi = new StormProcessingItem(processor, ID, PARRALLELISM_HINT_2);
- }
+ @Before
+ public void setUp() {
+ pi = new StormProcessingItem(processor, ID, PARRALLELISM_HINT_2);
+ }
- @Test
- public void testAddToTopology() {
- new Expectations() {
- {
- topology.getStormBuilder();
- result = stormBuilder;
+ @Test
+ public void testAddToTopology() {
+ new Expectations() {
+ {
+ topology.getStormBuilder();
+ result = stormBuilder;
- stormBuilder.setBolt(ID, (IRichBolt) any, anyInt);
- result = new MockUp<BoltDeclarer>() {
- }.getMockInstance();
- }
- };
+ stormBuilder.setBolt(ID, (IRichBolt) any, anyInt);
+ result = new MockUp<BoltDeclarer>() {
+ }.getMockInstance();
+ }
+ };
- pi.addToTopology(topology, PARRALLELISM_HINT_4); // this parallelism hint is ignored
+ pi.addToTopology(topology, PARRALLELISM_HINT_4); // this parallelism hint is
+ // ignored
- new Verifications() {
- {
- assertEquals(pi.getProcessor(), processor);
- // TODO add methods to explore a topology and verify them
- assertEquals(pi.getParallelism(), PARRALLELISM_HINT_2);
- assertEquals(pi.getId(), ID);
- }
- };
- }
+ new Verifications() {
+ {
+ assertEquals(pi.getProcessor(), processor);
+ // TODO add methods to explore a topology and verify them
+ assertEquals(pi.getParallelism(), PARRALLELISM_HINT_2);
+ assertEquals(pi.getId(), ID);
+ }
+ };
+ }
}
diff --git a/samoa-test/src/test/java/com/yahoo/labs/samoa/TestParams.java b/samoa-test/src/test/java/com/yahoo/labs/samoa/TestParams.java
index 08ad94f..26b6b74 100644
--- a/samoa-test/src/test/java/com/yahoo/labs/samoa/TestParams.java
+++ b/samoa-test/src/test/java/com/yahoo/labs/samoa/TestParams.java
@@ -2,234 +2,234 @@
public class TestParams {
- /**
- * templates that take the following parameters:
- * <ul>
- * <li>the output file location as an argument (-d),
- * <li>the maximum number of instances for testing/training (-i)
- * <li>the sampling size (-f)
- * <li>the delay in ms between input instances (-w) , default is zero
- * </ul>
- * as well as the maximum number of instances for testing/training (-i) and the sampling size (-f)
- */
- public static class Templates {
+ /**
+ * templates that take the following parameters:
+ * <ul>
+ * <li>the output file location as an argument (-d),
+ * <li>the maximum number of instances for testing/training (-i)
+ * <li>the sampling size (-f)
+ * <li>the delay in ms between input instances (-w) , default is zero
+ * </ul>
+ * as well as the maximum number of instances for testing/training (-i) and
+ * the sampling size (-f)
+ */
+ public static class Templates {
- public final static String PREQEVAL_VHT_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
- + "-l (com.yahoo.labs.samoa.learners.classifiers.trees.VerticalHoeffdingTree -p 4) " +
- "-s (com.yahoo.labs.samoa.moa.streams.generators.RandomTreeGenerator -c 2 -o 10 -u 10)";
+ public final static String PREQEVAL_VHT_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
+ + "-l (com.yahoo.labs.samoa.learners.classifiers.trees.VerticalHoeffdingTree -p 4) " +
+ "-s (com.yahoo.labs.samoa.moa.streams.generators.RandomTreeGenerator -c 2 -o 10 -u 10)";
- public final static String PREQEVAL_NAIVEBAYES_HYPERPLANE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
- + "-l (classifiers.SingleClassifier -l com.yahoo.labs.samoa.learners.classifiers.NaiveBayes) " +
- "-s (com.yahoo.labs.samoa.moa.streams.generators.HyperplaneGenerator -c 2)";
+ public final static String PREQEVAL_NAIVEBAYES_HYPERPLANE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
+ + "-l (classifiers.SingleClassifier -l com.yahoo.labs.samoa.learners.classifiers.NaiveBayes) " +
+ "-s (com.yahoo.labs.samoa.moa.streams.generators.HyperplaneGenerator -c 2)";
- // setting the number of nominal attributes to zero significantly reduces the processing time,
- // so that it's acceptable in a test case
- public final static String PREQEVAL_BAGGING_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
- + "-l (com.yahoo.labs.samoa.learners.classifiers.ensemble.Bagging) " +
- "-s (com.yahoo.labs.samoa.moa.streams.generators.RandomTreeGenerator -c 2 -o 0 -u 10)";
+ // setting the number of nominal attributes to zero significantly reduces
+ // the processing time,
+ // so that it's acceptable in a test case
+ public final static String PREQEVAL_BAGGING_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d "
+ + "-l (com.yahoo.labs.samoa.learners.classifiers.ensemble.Bagging) " +
+ "-s (com.yahoo.labs.samoa.moa.streams.generators.RandomTreeGenerator -c 2 -o 0 -u 10)";
- }
+ }
+ public static final String EVALUATION_INSTANCES = "evaluation instances";
+ public static final String CLASSIFIED_INSTANCES = "classified instances";
+ public static final String CLASSIFICATIONS_CORRECT = "classifications correct (percent)";
+ public static final String KAPPA_STAT = "Kappa Statistic (percent)";
+ public static final String KAPPA_TEMP_STAT = "Kappa Temporal Statistic (percent)";
- public static final String EVALUATION_INSTANCES = "evaluation instances";
- public static final String CLASSIFIED_INSTANCES = "classified instances";
- public static final String CLASSIFICATIONS_CORRECT = "classifications correct (percent)";
- public static final String KAPPA_STAT = "Kappa Statistic (percent)";
- public static final String KAPPA_TEMP_STAT = "Kappa Temporal Statistic (percent)";
+ private long inputInstances;
+ private long samplingSize;
+ private long evaluationInstances;
+ private long classifiedInstances;
+ private float classificationsCorrect;
+ private float kappaStat;
+ private float kappaTempStat;
+ private String cliStringTemplate;
+ private int pollTimeoutSeconds;
+ private final int prePollWait;
+ private int inputDelayMicroSec;
+ private String taskClassName;
+ private TestParams(String taskClassName,
+ long inputInstances,
+ long samplingSize,
+ long evaluationInstances,
+ long classifiedInstances,
+ float classificationsCorrect,
+ float kappaStat,
+ float kappaTempStat,
+ String cliStringTemplate,
+ int pollTimeoutSeconds,
+ int prePollWait,
+ int inputDelayMicroSec) {
+ this.taskClassName = taskClassName;
+ this.inputInstances = inputInstances;
+ this.samplingSize = samplingSize;
+ this.evaluationInstances = evaluationInstances;
+ this.classifiedInstances = classifiedInstances;
+ this.classificationsCorrect = classificationsCorrect;
+ this.kappaStat = kappaStat;
+ this.kappaTempStat = kappaTempStat;
+ this.cliStringTemplate = cliStringTemplate;
+ this.pollTimeoutSeconds = pollTimeoutSeconds;
+ this.prePollWait = prePollWait;
+ this.inputDelayMicroSec = inputDelayMicroSec;
+ }
+ public String getTaskClassName() {
+ return taskClassName;
+ }
+
+ public long getInputInstances() {
+ return inputInstances;
+ }
+
+ public long getSamplingSize() {
+ return samplingSize;
+ }
+
+ public int getPollTimeoutSeconds() {
+ return pollTimeoutSeconds;
+ }
+
+ public int getPrePollWaitSeconds() {
+ return prePollWait;
+ }
+
+ public String getCliStringTemplate() {
+ return cliStringTemplate;
+ }
+
+ public long getEvaluationInstances() {
+ return evaluationInstances;
+ }
+
+ public long getClassifiedInstances() {
+ return classifiedInstances;
+ }
+
+ public float getClassificationsCorrect() {
+ return classificationsCorrect;
+ }
+
+ public float getKappaStat() {
+ return kappaStat;
+ }
+
+ public float getKappaTempStat() {
+ return kappaTempStat;
+ }
+
+ public int getInputDelayMicroSec() {
+ return inputDelayMicroSec;
+ }
+
+ @Override
+ public String toString() {
+ return "TestParams{\n" +
+ "inputInstances=" + inputInstances + "\n" +
+ "samplingSize=" + samplingSize + "\n" +
+ "evaluationInstances=" + evaluationInstances + "\n" +
+ "classifiedInstances=" + classifiedInstances + "\n" +
+ "classificationsCorrect=" + classificationsCorrect + "\n" +
+ "kappaStat=" + kappaStat + "\n" +
+ "kappaTempStat=" + kappaTempStat + "\n" +
+ "cliStringTemplate='" + cliStringTemplate + '\'' + "\n" +
+ "pollTimeoutSeconds=" + pollTimeoutSeconds + "\n" +
+ "prePollWait=" + prePollWait + "\n" +
+ "taskClassName='" + taskClassName + '\'' + "\n" +
+ "inputDelayMicroSec=" + inputDelayMicroSec + "\n" +
+ '}';
+ }
+
+ public static class Builder {
private long inputInstances;
private long samplingSize;
private long evaluationInstances;
private long classifiedInstances;
private float classificationsCorrect;
- private float kappaStat;
- private float kappaTempStat;
+ private float kappaStat = 0f;
+ private float kappaTempStat = 0f;
private String cliStringTemplate;
- private int pollTimeoutSeconds;
- private final int prePollWait;
- private int inputDelayMicroSec;
+ private int pollTimeoutSeconds = 10;
+ private int prePollWaitSeconds = 10;
private String taskClassName;
+ private int inputDelayMicroSec = 0;
- private TestParams(String taskClassName,
- long inputInstances,
- long samplingSize,
- long evaluationInstances,
- long classifiedInstances,
- float classificationsCorrect,
- float kappaStat,
- float kappaTempStat,
- String cliStringTemplate,
- int pollTimeoutSeconds,
- int prePollWait,
- int inputDelayMicroSec) {
- this.taskClassName = taskClassName;
- this.inputInstances = inputInstances;
- this.samplingSize = samplingSize;
- this.evaluationInstances = evaluationInstances;
- this.classifiedInstances = classifiedInstances;
- this.classificationsCorrect = classificationsCorrect;
- this.kappaStat = kappaStat;
- this.kappaTempStat = kappaTempStat;
- this.cliStringTemplate = cliStringTemplate;
- this.pollTimeoutSeconds = pollTimeoutSeconds;
- this.prePollWait = prePollWait;
- this.inputDelayMicroSec = inputDelayMicroSec;
+ public Builder taskClassName(String taskClassName) {
+ this.taskClassName = taskClassName;
+ return this;
}
- public String getTaskClassName() {
- return taskClassName;
+ public Builder inputInstances(long inputInstances) {
+ this.inputInstances = inputInstances;
+ return this;
}
- public long getInputInstances() {
- return inputInstances;
+ public Builder samplingSize(long samplingSize) {
+ this.samplingSize = samplingSize;
+ return this;
}
- public long getSamplingSize() {
- return samplingSize;
+ public Builder evaluationInstances(long evaluationInstances) {
+ this.evaluationInstances = evaluationInstances;
+ return this;
}
- public int getPollTimeoutSeconds() {
- return pollTimeoutSeconds;
+ public Builder classifiedInstances(long classifiedInstances) {
+ this.classifiedInstances = classifiedInstances;
+ return this;
}
- public int getPrePollWaitSeconds() {
- return prePollWait;
+ public Builder classificationsCorrect(float classificationsCorrect) {
+ this.classificationsCorrect = classificationsCorrect;
+ return this;
}
- public String getCliStringTemplate() {
- return cliStringTemplate;
+ public Builder kappaStat(float kappaStat) {
+ this.kappaStat = kappaStat;
+ return this;
}
- public long getEvaluationInstances() {
- return evaluationInstances;
+ public Builder kappaTempStat(float kappaTempStat) {
+ this.kappaTempStat = kappaTempStat;
+ return this;
}
- public long getClassifiedInstances() {
- return classifiedInstances;
+ public Builder cliStringTemplate(String cliStringTemplate) {
+ this.cliStringTemplate = cliStringTemplate;
+ return this;
}
- public float getClassificationsCorrect() {
- return classificationsCorrect;
+ public Builder resultFilePollTimeout(int pollTimeoutSeconds) {
+ this.pollTimeoutSeconds = pollTimeoutSeconds;
+ return this;
}
- public float getKappaStat() {
- return kappaStat;
+ public Builder inputDelayMicroSec(int inputDelayMicroSec) {
+ this.inputDelayMicroSec = inputDelayMicroSec;
+ return this;
}
- public float getKappaTempStat() {
- return kappaTempStat;
+ public Builder prePollWait(int prePollWaitSeconds) {
+ this.prePollWaitSeconds = prePollWaitSeconds;
+ return this;
}
- public int getInputDelayMicroSec() {
- return inputDelayMicroSec;
+ public TestParams build() {
+ return new TestParams(taskClassName,
+ inputInstances,
+ samplingSize,
+ evaluationInstances,
+ classifiedInstances,
+ classificationsCorrect,
+ kappaStat,
+ kappaTempStat,
+ cliStringTemplate,
+ pollTimeoutSeconds,
+ prePollWaitSeconds,
+ inputDelayMicroSec);
}
-
- @Override
- public String toString() {
- return "TestParams{\n" +
- "inputInstances=" + inputInstances + "\n" +
- "samplingSize=" + samplingSize + "\n" +
- "evaluationInstances=" + evaluationInstances + "\n" +
- "classifiedInstances=" + classifiedInstances + "\n" +
- "classificationsCorrect=" + classificationsCorrect + "\n" +
- "kappaStat=" + kappaStat + "\n" +
- "kappaTempStat=" + kappaTempStat + "\n" +
- "cliStringTemplate='" + cliStringTemplate + '\'' + "\n" +
- "pollTimeoutSeconds=" + pollTimeoutSeconds + "\n" +
- "prePollWait=" + prePollWait + "\n" +
- "taskClassName='" + taskClassName + '\'' + "\n" +
- "inputDelayMicroSec=" + inputDelayMicroSec + "\n" +
- '}';
- }
-
- public static class Builder {
- private long inputInstances;
- private long samplingSize;
- private long evaluationInstances;
- private long classifiedInstances;
- private float classificationsCorrect;
- private float kappaStat =0f;
- private float kappaTempStat =0f;
- private String cliStringTemplate;
- private int pollTimeoutSeconds = 10;
- private int prePollWaitSeconds = 10;
- private String taskClassName;
- private int inputDelayMicroSec = 0;
-
- public Builder taskClassName(String taskClassName) {
- this.taskClassName = taskClassName;
- return this;
- }
-
- public Builder inputInstances(long inputInstances) {
- this.inputInstances = inputInstances;
- return this;
- }
-
- public Builder samplingSize(long samplingSize) {
- this.samplingSize = samplingSize;
- return this;
- }
-
- public Builder evaluationInstances(long evaluationInstances) {
- this.evaluationInstances = evaluationInstances;
- return this;
- }
-
- public Builder classifiedInstances(long classifiedInstances) {
- this.classifiedInstances = classifiedInstances;
- return this;
- }
-
- public Builder classificationsCorrect(float classificationsCorrect) {
- this.classificationsCorrect = classificationsCorrect;
- return this;
- }
-
- public Builder kappaStat(float kappaStat) {
- this.kappaStat = kappaStat;
- return this;
- }
-
- public Builder kappaTempStat(float kappaTempStat) {
- this.kappaTempStat = kappaTempStat;
- return this;
- }
-
- public Builder cliStringTemplate(String cliStringTemplate) {
- this.cliStringTemplate = cliStringTemplate;
- return this;
- }
-
- public Builder resultFilePollTimeout(int pollTimeoutSeconds) {
- this.pollTimeoutSeconds = pollTimeoutSeconds;
- return this;
- }
-
- public Builder inputDelayMicroSec(int inputDelayMicroSec) {
- this.inputDelayMicroSec = inputDelayMicroSec;
- return this;
- }
-
- public Builder prePollWait(int prePollWaitSeconds) {
- this.prePollWaitSeconds = prePollWaitSeconds;
- return this;
- }
-
- public TestParams build() {
- return new TestParams(taskClassName,
- inputInstances,
- samplingSize,
- evaluationInstances,
- classifiedInstances,
- classificationsCorrect,
- kappaStat,
- kappaTempStat,
- cliStringTemplate,
- pollTimeoutSeconds,
- prePollWaitSeconds,
- inputDelayMicroSec);
- }
- }
+ }
}
diff --git a/samoa-test/src/test/java/com/yahoo/labs/samoa/TestUtils.java b/samoa-test/src/test/java/com/yahoo/labs/samoa/TestUtils.java
index d66f5df..c36706c 100644
--- a/samoa-test/src/test/java/com/yahoo/labs/samoa/TestUtils.java
+++ b/samoa-test/src/test/java/com/yahoo/labs/samoa/TestUtils.java
@@ -41,113 +41,112 @@
public class TestUtils {
- private static final Logger LOG = LoggerFactory.getLogger(TestUtils.class.getName());
+ private static final Logger LOG = LoggerFactory.getLogger(TestUtils.class.getName());
+ public static void test(final TestParams testParams) throws IOException, ClassNotFoundException,
+ NoSuchMethodException, InvocationTargetException, IllegalAccessException, InterruptedException {
- public static void test(final TestParams testParams) throws IOException, ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException, InterruptedException {
+ final File tempFile = File.createTempFile("test", "test");
- final File tempFile = File.createTempFile("test", "test");
+ LOG.info("Starting test, output file is {}, test config is \n{}", tempFile.getAbsolutePath(), testParams.toString());
- LOG.info("Starting test, output file is {}, test config is \n{}", tempFile.getAbsolutePath(), testParams.toString());
+ Executors.newSingleThreadExecutor().submit(new Callable<Void>() {
- Executors.newSingleThreadExecutor().submit(new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ try {
+ Class.forName(testParams.getTaskClassName())
+ .getMethod("main", String[].class)
+ .invoke(null, (Object) String.format(
+ testParams.getCliStringTemplate(),
+ tempFile.getAbsolutePath(),
+ testParams.getInputInstances(),
+ testParams.getSamplingSize(),
+ testParams.getInputDelayMicroSec()
+ ).split("[ ]"));
+ } catch (Exception e) {
+ LOG.error("Cannot execute test {} {}", e.getMessage(), e.getCause().getMessage());
+ }
+ return null;
+ }
+ });
- @Override
- public Void call() throws Exception {
- try {
- Class.forName(testParams.getTaskClassName())
- .getMethod("main", String[].class)
- .invoke(null, (Object) String.format(
- testParams.getCliStringTemplate(),
- tempFile.getAbsolutePath(),
- testParams.getInputInstances(),
- testParams.getSamplingSize(),
- testParams.getInputDelayMicroSec()
- ).split("[ ]"));
- } catch (Exception e) {
- LOG.error("Cannot execute test {} {}", e.getMessage(), e.getCause().getMessage());
- }
- return null;
- }
- });
+ Thread.sleep(TimeUnit.SECONDS.toMillis(testParams.getPrePollWaitSeconds()));
- Thread.sleep(TimeUnit.SECONDS.toMillis(testParams.getPrePollWaitSeconds()));
+ CountDownLatch signalComplete = new CountDownLatch(1);
- CountDownLatch signalComplete = new CountDownLatch(1);
+ final Tailer tailer = Tailer.create(tempFile, new TestResultsTailerAdapter(signalComplete), 1000);
+ new Thread(new Runnable() {
+ @Override
+ public void run() {
+ tailer.run();
+ }
+ }).start();
- final Tailer tailer = Tailer.create(tempFile, new TestResultsTailerAdapter(signalComplete), 1000);
- new Thread(new Runnable() {
- @Override
- public void run() {
- tailer.run();
- }
- }).start();
+ signalComplete.await();
+ tailer.stop();
- signalComplete.await();
- tailer.stop();
+ assertResults(tempFile, testParams);
+ }
- assertResults(tempFile, testParams);
+ public static void assertResults(File outputFile, com.yahoo.labs.samoa.TestParams testParams) throws IOException {
+
+ LOG.info("Checking results file " + outputFile.getAbsolutePath());
+ // 1. parse result file with csv parser
+ Reader in = new FileReader(outputFile);
+ Iterable<CSVRecord> records = CSVFormat.EXCEL.withSkipHeaderRecord(false)
+ .withIgnoreEmptyLines(true).withDelimiter(',').withCommentMarker('#').parse(in);
+ CSVRecord last = null;
+ Iterator<CSVRecord> iterator = records.iterator();
+ CSVRecord header = iterator.next();
+ Assert.assertEquals("Invalid number of columns", 5, header.size());
+
+ Assert
+ .assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.EVALUATION_INSTANCES, header.get(0).trim());
+ Assert
+ .assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.CLASSIFIED_INSTANCES, header.get(1).trim());
+ Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.CLASSIFICATIONS_CORRECT, header.get(2)
+ .trim());
+ Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.KAPPA_STAT, header.get(3).trim());
+ Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.KAPPA_TEMP_STAT, header.get(4).trim());
+
+ // 2. check last line result
+ while (iterator.hasNext()) {
+ last = iterator.next();
}
- public static void assertResults(File outputFile, com.yahoo.labs.samoa.TestParams testParams) throws IOException {
+ assertTrue(String.format("Unmet threshold expected %d got %f",
+ testParams.getEvaluationInstances(), Float.parseFloat(last.get(0))),
+ testParams.getEvaluationInstances() <= Float.parseFloat(last.get(0)));
+ assertTrue(String.format("Unmet threshold expected %d got %f", testParams.getClassifiedInstances(),
+ Float.parseFloat(last.get(1))),
+ testParams.getClassifiedInstances() <= Float.parseFloat(last.get(1)));
+ assertTrue(String.format("Unmet threshold expected %f got %f",
+ testParams.getClassificationsCorrect(), Float.parseFloat(last.get(2))),
+ testParams.getClassificationsCorrect() <= Float.parseFloat(last.get(2)));
+ assertTrue(String.format("Unmet threshold expected %f got %f",
+ testParams.getKappaStat(), Float.parseFloat(last.get(3))),
+ testParams.getKappaStat() <= Float.parseFloat(last.get(3)));
+ assertTrue(String.format("Unmet threshold expected %f got %f",
+ testParams.getKappaTempStat(), Float.parseFloat(last.get(4))),
+ testParams.getKappaTempStat() <= Float.parseFloat(last.get(4)));
- LOG.info("Checking results file " + outputFile.getAbsolutePath());
- // 1. parse result file with csv parser
- Reader in = new FileReader(outputFile);
- Iterable<CSVRecord> records = CSVFormat.EXCEL.withSkipHeaderRecord(false)
- .withIgnoreEmptyLines(true).withDelimiter(',').withCommentMarker('#').parse(in);
- CSVRecord last = null;
- Iterator<CSVRecord> iterator = records.iterator();
- CSVRecord header = iterator.next();
- Assert.assertEquals("Invalid number of columns", 5, header.size());
+ }
- Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.EVALUATION_INSTANCES, header.get(0).trim());
- Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.CLASSIFIED_INSTANCES, header.get(1).trim());
- Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.CLASSIFICATIONS_CORRECT, header.get(2).trim());
- Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.KAPPA_STAT, header.get(3).trim());
- Assert.assertEquals("Unexpected column", com.yahoo.labs.samoa.TestParams.KAPPA_TEMP_STAT, header.get(4).trim());
+ private static class TestResultsTailerAdapter extends TailerListenerAdapter {
- // 2. check last line result
- while (iterator.hasNext()) {
- last = iterator.next();
- }
+ private final CountDownLatch signalComplete;
- assertTrue(String.format("Unmet threshold expected %d got %f",
- testParams.getEvaluationInstances(), Float.parseFloat(last.get(0))),
- testParams.getEvaluationInstances() <= Float.parseFloat(last.get(0)));
- assertTrue(String.format("Unmet threshold expected %d got %f", testParams.getClassifiedInstances(),
- Float.parseFloat(last.get(1))),
- testParams.getClassifiedInstances() <= Float.parseFloat(last.get(1)));
- assertTrue(String.format("Unmet threshold expected %f got %f",
- testParams.getClassificationsCorrect(), Float.parseFloat(last.get(2))),
- testParams.getClassificationsCorrect() <= Float.parseFloat(last.get(2)));
- assertTrue(String.format("Unmet threshold expected %f got %f",
- testParams.getKappaStat(), Float.parseFloat(last.get(3))),
- testParams.getKappaStat() <= Float.parseFloat(last.get(3)));
- assertTrue(String.format("Unmet threshold expected %f got %f",
- testParams.getKappaTempStat(), Float.parseFloat(last.get(4))),
- testParams.getKappaTempStat() <= Float.parseFloat(last.get(4)));
-
+ public TestResultsTailerAdapter(CountDownLatch signalComplete) {
+ this.signalComplete = signalComplete;
}
-
- private static class TestResultsTailerAdapter extends TailerListenerAdapter {
-
- private final CountDownLatch signalComplete;
-
- public TestResultsTailerAdapter(CountDownLatch signalComplete) {
- this.signalComplete = signalComplete;
- }
-
- @Override
- public void handle(String line) {
- if ("# COMPLETED".equals(line.trim())) {
- signalComplete.countDown();
- }
- }
+ @Override
+ public void handle(String line) {
+ if ("# COMPLETED".equals(line.trim())) {
+ signalComplete.countDown();
+ }
}
-
-
-
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/LocalThreadsDoTask.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/LocalThreadsDoTask.java
index 21ccf9e..2ac9ec1 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/LocalThreadsDoTask.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/LocalThreadsDoTask.java
@@ -13,58 +13,58 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class LocalThreadsDoTask {
- private static final Logger logger = LoggerFactory.getLogger(LocalThreadsDoTask.class);
+ private static final Logger logger = LoggerFactory.getLogger(LocalThreadsDoTask.class);
- /**
- * The main method.
- *
- * @param args
- * the arguments
- */
- public static void main(String[] args) {
+ /**
+ * The main method.
+ *
+ * @param args
+ * the arguments
+ */
+ public static void main(String[] args) {
- ArrayList<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
-
- // Get number of threads for multithreading mode
- int numThreads = 1;
- for (int i=0; i<tmpArgs.size()-1; i++) {
- if (tmpArgs.get(i).equals("-t")) {
- try {
- numThreads = Integer.parseInt(tmpArgs.get(i+1));
- tmpArgs.remove(i+1);
- tmpArgs.remove(i);
- } catch (NumberFormatException e) {
- System.err.println("Invalid number of threads.");
- System.err.println(e.getStackTrace());
- }
- }
- }
- logger.info("Number of threads:{}", numThreads);
-
- args = tmpArgs.toArray(new String[0]);
+ ArrayList<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
- StringBuilder cliString = new StringBuilder();
- for (int i = 0; i < args.length; i++) {
- cliString.append(" ").append(args[i]);
- }
- logger.debug("Command line string = {}", cliString.toString());
- System.out.println("Command line string = " + cliString.toString());
-
- Task task = null;
+ // Get number of threads for multithreading mode
+ int numThreads = 1;
+ for (int i = 0; i < tmpArgs.size() - 1; i++) {
+ if (tmpArgs.get(i).equals("-t")) {
try {
- task = (Task) ClassOption.cliStringToObject(cliString.toString(), Task.class, null);
- logger.info("Sucessfully instantiating {}", task.getClass().getCanonicalName());
- } catch (Exception e) {
- logger.error("Fail to initialize the task", e);
- System.out.println("Fail to initialize the task" + e);
- return;
+ numThreads = Integer.parseInt(tmpArgs.get(i + 1));
+ tmpArgs.remove(i + 1);
+ tmpArgs.remove(i);
+ } catch (NumberFormatException e) {
+ System.err.println("Invalid number of threads.");
+ System.err.println(e.getStackTrace());
}
- task.setFactory(new ThreadsComponentFactory());
- task.init();
-
- ThreadsEngine.submitTopology(task.getTopology(), numThreads);
+ }
}
+ logger.info("Number of threads:{}", numThreads);
+
+ args = tmpArgs.toArray(new String[0]);
+
+ StringBuilder cliString = new StringBuilder();
+ for (int i = 0; i < args.length; i++) {
+ cliString.append(" ").append(args[i]);
+ }
+ logger.debug("Command line string = {}", cliString.toString());
+ System.out.println("Command line string = " + cliString.toString());
+
+ Task task = null;
+ try {
+ task = (Task) ClassOption.cliStringToObject(cliString.toString(), Task.class, null);
+      logger.info("Successfully instantiating {}", task.getClass().getCanonicalName());
+ } catch (Exception e) {
+ logger.error("Fail to initialize the task", e);
+ System.out.println("Fail to initialize the task" + e);
+ return;
+ }
+ task.setFactory(new ThreadsComponentFactory());
+ task.init();
+
+ ThreadsEngine.submitTopology(task.getTopology(), numThreads);
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactory.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactory.java
index ac68da2..91f213b 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactory.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactory.java
@@ -31,34 +31,35 @@
/**
* ComponentFactory for multithreaded engine
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsComponentFactory implements ComponentFactory {
- @Override
- public ProcessingItem createPi(Processor processor) {
- return this.createPi(processor, 1);
- }
+ @Override
+ public ProcessingItem createPi(Processor processor) {
+ return this.createPi(processor, 1);
+ }
- @Override
- public ProcessingItem createPi(Processor processor, int paralellism) {
- return new ThreadsProcessingItem(processor, paralellism);
- }
+ @Override
+  public ProcessingItem createPi(Processor processor, int parallelism) {
+    return new ThreadsProcessingItem(processor, parallelism);
+ }
- @Override
- public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor) {
- return new ThreadsEntranceProcessingItem(entranceProcessor);
- }
+ @Override
+ public EntranceProcessingItem createEntrancePi(EntranceProcessor entranceProcessor) {
+ return new ThreadsEntranceProcessingItem(entranceProcessor);
+ }
- @Override
- public Stream createStream(IProcessingItem sourcePi) {
- return new ThreadsStream(sourcePi);
- }
+ @Override
+ public Stream createStream(IProcessingItem sourcePi) {
+ return new ThreadsStream(sourcePi);
+ }
- @Override
- public Topology createTopology(String topoName) {
- return new ThreadsTopology(topoName);
- }
+ @Override
+ public Topology createTopology(String topoName) {
+ return new ThreadsTopology(topoName);
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngine.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngine.java
index d442572..c266c09 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngine.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngine.java
@@ -30,71 +30,72 @@
/**
* Multithreaded engine.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsEngine {
-
- private static final List<ExecutorService> threadPool = new ArrayList<ExecutorService>();
-
- /*
- * Create and manage threads
- */
- public static void setNumberOfThreads(int numThreads) {
- if (numThreads < 1)
- throw new IllegalStateException("Number of threads must be a positive integer.");
-
- if (threadPool.size() > numThreads)
- throw new IllegalStateException("You cannot set a numThreads smaller than the current size of the threads pool.");
-
- if (threadPool.size() < numThreads) {
- for (int i=threadPool.size(); i<numThreads; i++) {
- threadPool.add(Executors.newSingleThreadExecutor());
- }
- }
- }
-
- public static int getNumberOfThreads() {
- return threadPool.size();
- }
-
- public static ExecutorService getThreadWithIndex(int index) {
- if (threadPool.size() <= 0 )
- throw new IllegalStateException("Try to get ExecutorService from an empty pool.");
- index %= threadPool.size();
- return threadPool.get(index);
- }
-
- /*
- * Submit topology and start
- */
- private static void submitTopology(Topology topology) {
- ThreadsTopology tTopology = (ThreadsTopology) topology;
- tTopology.run();
- }
-
- public static void submitTopology(Topology topology, int numThreads) {
- ThreadsEngine.setNumberOfThreads(numThreads);
- ThreadsEngine.submitTopology(topology);
- }
-
- /*
- * Stop
- */
- public static void clearThreadPool() {
- for (ExecutorService pool:threadPool) {
- pool.shutdown();
- }
-
- for (ExecutorService pool:threadPool) {
- try {
- pool.awaitTermination(10, TimeUnit.SECONDS);
- } catch (InterruptedException e) {
- e.printStackTrace();
- }
- }
-
- threadPool.clear();
- }
+
+ private static final List<ExecutorService> threadPool = new ArrayList<ExecutorService>();
+
+ /*
+ * Create and manage threads
+ */
+ public static void setNumberOfThreads(int numThreads) {
+ if (numThreads < 1)
+ throw new IllegalStateException("Number of threads must be a positive integer.");
+
+ if (threadPool.size() > numThreads)
+ throw new IllegalStateException("You cannot set a numThreads smaller than the current size of the threads pool.");
+
+ if (threadPool.size() < numThreads) {
+ for (int i = threadPool.size(); i < numThreads; i++) {
+ threadPool.add(Executors.newSingleThreadExecutor());
+ }
+ }
+ }
+
+ public static int getNumberOfThreads() {
+ return threadPool.size();
+ }
+
+ public static ExecutorService getThreadWithIndex(int index) {
+ if (threadPool.size() <= 0)
+ throw new IllegalStateException("Try to get ExecutorService from an empty pool.");
+ index %= threadPool.size();
+ return threadPool.get(index);
+ }
+
+ /*
+ * Submit topology and start
+ */
+ private static void submitTopology(Topology topology) {
+ ThreadsTopology tTopology = (ThreadsTopology) topology;
+ tTopology.run();
+ }
+
+ public static void submitTopology(Topology topology, int numThreads) {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ ThreadsEngine.submitTopology(topology);
+ }
+
+ /*
+ * Stop
+ */
+ public static void clearThreadPool() {
+ for (ExecutorService pool : threadPool) {
+ pool.shutdown();
+ }
+
+ for (ExecutorService pool : threadPool) {
+ try {
+ pool.awaitTermination(10, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+ threadPool.clear();
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItem.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItem.java
index 008efb6..470c164 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItem.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItem.java
@@ -25,16 +25,17 @@
/**
* EntranceProcessingItem for multithreaded engine.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsEntranceProcessingItem extends LocalEntranceProcessingItem {
-
- public ThreadsEntranceProcessingItem(EntranceProcessor processor) {
- super(processor);
- }
-
- // The default waiting time when there is no available events is 100ms
- // Override waitForNewEvents() to change it
+
+ public ThreadsEntranceProcessingItem(EntranceProcessor processor) {
+ super(processor);
+ }
+
+ // The default waiting time when there is no available events is 100ms
+ // Override waitForNewEvents() to change it
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnable.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnable.java
index 7cb8c18..4dd33db 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnable.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnable.java
@@ -23,39 +23,41 @@
import com.yahoo.labs.samoa.core.ContentEvent;
/**
- * Runnable class where each object corresponds to a ContentEvent and an assigned PI.
- * When a PI receives a ContentEvent, it will create a ThreadsEventRunnable with the received ContentEvent
- * and an assigned workerPI. This runnable is then submitted to a thread queue waiting to be executed.
- * The worker PI will process the received event when the runnable object is executed/run.
+ * Runnable class where each object corresponds to a ContentEvent and an
+ * assigned PI. When a PI receives a ContentEvent, it will create a
+ * ThreadsEventRunnable with the received ContentEvent and an assigned workerPI.
+ * This runnable is then submitted to a thread queue waiting to be executed. The
+ * worker PI will process the received event when the runnable object is
+ * executed/run.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsEventRunnable implements Runnable {
- private ThreadsProcessingItemInstance workerPi;
- private ContentEvent event;
-
- public ThreadsEventRunnable(ThreadsProcessingItemInstance workerPi, ContentEvent event) {
- this.workerPi = workerPi;
- this.event = event;
- }
-
- public ThreadsProcessingItemInstance getWorkerProcessingItem() {
- return this.workerPi;
- }
-
- public ContentEvent getContentEvent() {
- return this.event;
- }
-
- @Override
- public void run() {
- try {
- workerPi.processEvent(event);
- }
- catch (Exception e) {
- e.printStackTrace();
- }
- }
+ private ThreadsProcessingItemInstance workerPi;
+ private ContentEvent event;
+
+ public ThreadsEventRunnable(ThreadsProcessingItemInstance workerPi, ContentEvent event) {
+ this.workerPi = workerPi;
+ this.event = event;
+ }
+
+ public ThreadsProcessingItemInstance getWorkerProcessingItem() {
+ return this.workerPi;
+ }
+
+ public ContentEvent getContentEvent() {
+ return this.event;
+ }
+
+ @Override
+ public void run() {
+ try {
+ workerPi.processEvent(event);
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItem.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItem.java
index 5eb6174..1b83a05 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItem.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItem.java
@@ -33,69 +33,71 @@
/**
* ProcessingItem for multithreaded engine.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsProcessingItem extends AbstractProcessingItem {
- // Replicas of the ProcessingItem.
- // When ProcessingItem receives an event, it assigns one
- // of these replicas to process the event.
- private List<ThreadsProcessingItemInstance> piInstances;
-
- // Each replica of ProcessingItem is assigned to one of the
- // available threads in a round-robin fashion, i.e.: each
- // replica is associated with the index of a thread.
- // Each ProcessingItem has a random offset variable so that
- // the allocation of PI replicas to threads are spread evenly
- // among all threads.
- private int offset;
-
- /*
- * Constructor
- */
- public ThreadsProcessingItem(Processor processor, int parallelismHint) {
- super(processor, parallelismHint);
- this.offset = (int) (Math.random()*ThreadsEngine.getNumberOfThreads());
- }
-
- public List<ThreadsProcessingItemInstance> getProcessingItemInstances() {
- return this.piInstances;
- }
+ // Replicas of the ProcessingItem.
+ // When ProcessingItem receives an event, it assigns one
+ // of these replicas to process the event.
+ private List<ThreadsProcessingItemInstance> piInstances;
- /*
- * Connects to streams
- */
- @Override
- protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
- StreamDestination destination = new StreamDestination(this, this.getParallelism(), scheme);
- ((ThreadsStream) inputStream).addDestination(destination);
- return this;
- }
+ // Each replica of ProcessingItem is assigned to one of the
+ // available threads in a round-robin fashion, i.e.: each
+ // replica is associated with the index of a thread.
+ // Each ProcessingItem has a random offset variable so that
+ // the allocation of PI replicas to threads are spread evenly
+ // among all threads.
+ private int offset;
- /*
- * Process the received event.
- */
- public void processEvent(ContentEvent event, int counter) {
- if (this.piInstances == null || this.piInstances.size() < this.getParallelism())
- throw new IllegalStateException("ThreadsWorkerProcessingItem(s) need to be setup before process any event (i.e. in ThreadsTopology.start()).");
-
- ThreadsProcessingItemInstance piInstance = this.piInstances.get(counter);
- ThreadsEventRunnable runnable = new ThreadsEventRunnable(piInstance, event);
- ThreadsEngine.getThreadWithIndex(piInstance.getThreadIndex()).submit(runnable);
- }
-
- /*
- * Setup the replicas of this PI.
- * This should be called after the topology is set up (all Processors and PIs are
- * setup and connected to the respective streams) and before events are sent.
- */
- public void setupInstances() {
- this.piInstances = new ArrayList<ThreadsProcessingItemInstance>(this.getParallelism());
- for (int i=0; i<this.getParallelism(); i++) {
- Processor newProcessor = this.getProcessor().newProcessor(this.getProcessor());
- newProcessor.onCreate(i + 1);
- this.piInstances.add(new ThreadsProcessingItemInstance(newProcessor, this.offset + i));
- }
- }
+ /*
+ * Constructor
+ */
+ public ThreadsProcessingItem(Processor processor, int parallelismHint) {
+ super(processor, parallelismHint);
+ this.offset = (int) (Math.random() * ThreadsEngine.getNumberOfThreads());
+ }
+
+ public List<ThreadsProcessingItemInstance> getProcessingItemInstances() {
+ return this.piInstances;
+ }
+
+ /*
+ * Connects to streams
+ */
+ @Override
+ protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
+ StreamDestination destination = new StreamDestination(this, this.getParallelism(), scheme);
+ ((ThreadsStream) inputStream).addDestination(destination);
+ return this;
+ }
+
+ /*
+ * Process the received event.
+ */
+ public void processEvent(ContentEvent event, int counter) {
+ if (this.piInstances == null || this.piInstances.size() < this.getParallelism())
+ throw new IllegalStateException(
+ "ThreadsWorkerProcessingItem(s) need to be setup before process any event (i.e. in ThreadsTopology.start()).");
+
+ ThreadsProcessingItemInstance piInstance = this.piInstances.get(counter);
+ ThreadsEventRunnable runnable = new ThreadsEventRunnable(piInstance, event);
+ ThreadsEngine.getThreadWithIndex(piInstance.getThreadIndex()).submit(runnable);
+ }
+
+ /*
+ * Setup the replicas of this PI. This should be called after the topology is
+ * set up (all Processors and PIs are setup and connected to the respective
+ * streams) and before events are sent.
+ */
+ public void setupInstances() {
+ this.piInstances = new ArrayList<ThreadsProcessingItemInstance>(this.getParallelism());
+ for (int i = 0; i < this.getParallelism(); i++) {
+ Processor newProcessor = this.getProcessor().newProcessor(this.getProcessor());
+ newProcessor.onCreate(i + 1);
+ this.piInstances.add(new ThreadsProcessingItemInstance(newProcessor, this.offset + i));
+ }
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstance.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstance.java
index 9a400d1..73052ea 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstance.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstance.java
@@ -24,31 +24,32 @@
import com.yahoo.labs.samoa.core.Processor;
/**
- * Lightweight replicas of ThreadProcessingItem.
- * ThreadsProcessingItem manages a list of these objects and
- * assigns each incoming message to be processed by one of them.
+ * Lightweight replicas of ThreadProcessingItem. ThreadsProcessingItem manages a
+ * list of these objects and assigns each incoming message to be processed by
+ * one of them.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsProcessingItemInstance {
- private Processor processor;
- private int threadIndex;
-
- public ThreadsProcessingItemInstance(Processor processor, int threadIndex) {
- this.processor = processor;
- this.threadIndex = threadIndex;
- }
-
- public int getThreadIndex() {
- return this.threadIndex;
- }
-
- public Processor getProcessor() {
- return this.processor;
- }
+ private Processor processor;
+ private int threadIndex;
- public void processEvent(ContentEvent event) {
- this.processor.process(event);
- }
+ public ThreadsProcessingItemInstance(Processor processor, int threadIndex) {
+ this.processor = processor;
+ this.threadIndex = threadIndex;
+ }
+
+ public int getThreadIndex() {
+ return this.threadIndex;
+ }
+
+ public Processor getProcessor() {
+ return this.processor;
+ }
+
+ public void processEvent(ContentEvent event) {
+ this.processor.process(event);
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsStream.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsStream.java
index 5aa86f7..2c02df7 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsStream.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsStream.java
@@ -32,75 +32,78 @@
/**
* Stream for multithreaded engine.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsStream extends AbstractStream {
-
- private List<StreamDestination> destinations;
- private int counter = 0;
- private int maxCounter = 1;
-
- public ThreadsStream(IProcessingItem sourcePi) {
- destinations = new LinkedList<StreamDestination>();
- }
-
- public void addDestination(StreamDestination destination) {
- destinations.add(destination);
- maxCounter *= destination.getParallelism();
- }
-
- public List<StreamDestination> getDestinations() {
- return this.destinations;
- }
-
- private int getNextCounter() {
- if (maxCounter > 0 && counter >= maxCounter) counter = 0;
- this.counter++;
- return this.counter;
- }
- @Override
- public synchronized void put(ContentEvent event) {
- this.put(event, this.getNextCounter());
- }
-
- private void put(ContentEvent event, int counter) {
- ThreadsProcessingItem pi;
- int parallelism;
- for (StreamDestination destination:destinations) {
- pi = (ThreadsProcessingItem) destination.getProcessingItem();
- parallelism = destination.getParallelism();
- switch (destination.getPartitioningScheme()) {
- case SHUFFLE:
- pi.processEvent(event, counter%parallelism);
- break;
- case GROUP_BY_KEY:
- pi.processEvent(event, getPIIndexForKey(event.getKey(), parallelism));
- break;
- case BROADCAST:
- for (int p = 0; p < parallelism; p++) {
- pi.processEvent(event, p);
- }
- break;
- }
+ private List<StreamDestination> destinations;
+ private int counter = 0;
+ private int maxCounter = 1;
+
+ public ThreadsStream(IProcessingItem sourcePi) {
+ destinations = new LinkedList<StreamDestination>();
+ }
+
+ public void addDestination(StreamDestination destination) {
+ destinations.add(destination);
+ maxCounter *= destination.getParallelism();
+ }
+
+ public List<StreamDestination> getDestinations() {
+ return this.destinations;
+ }
+
+ private int getNextCounter() {
+ if (maxCounter > 0 && counter >= maxCounter)
+ counter = 0;
+ this.counter++;
+ return this.counter;
+ }
+
+ @Override
+ public synchronized void put(ContentEvent event) {
+ this.put(event, this.getNextCounter());
+ }
+
+ private void put(ContentEvent event, int counter) {
+ ThreadsProcessingItem pi;
+ int parallelism;
+ for (StreamDestination destination : destinations) {
+ pi = (ThreadsProcessingItem) destination.getProcessingItem();
+ parallelism = destination.getParallelism();
+ switch (destination.getPartitioningScheme()) {
+ case SHUFFLE:
+ pi.processEvent(event, counter % parallelism);
+ break;
+ case GROUP_BY_KEY:
+ pi.processEvent(event, getPIIndexForKey(event.getKey(), parallelism));
+ break;
+ case BROADCAST:
+ for (int p = 0; p < parallelism; p++) {
+ pi.processEvent(event, p);
}
+ break;
+ }
}
-
- private static int getPIIndexForKey(String key, int parallelism) {
- // If key is null, return a default index: 0
- if (key == null) return 0;
-
- // HashCodeBuilder object does not have reset() method
- // So all objects that get appended will be included in the
- // computation of the hashcode.
- // To avoid initialize a HashCodeBuilder for each event,
- // here I use the static method with reflection on the event's key
- int index = HashCodeBuilder.reflectionHashCode(key, true) % parallelism;
- if (index < 0) {
- index += parallelism;
- }
- return index;
- }
+ }
+
+ private static int getPIIndexForKey(String key, int parallelism) {
+ // If key is null, return a default index: 0
+ if (key == null)
+ return 0;
+
+ // HashCodeBuilder object does not have reset() method
+ // So all objects that get appended will be included in the
+ // computation of the hashcode.
+ // To avoid initialize a HashCodeBuilder for each event,
+ // here I use the static method with reflection on the event's key
+ int index = HashCodeBuilder.reflectionHashCode(key, true) % parallelism;
+ if (index < 0) {
+ index += parallelism;
+ }
+ return index;
+ }
}
diff --git a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopology.java b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopology.java
index fc9f885..a6bad2b 100644
--- a/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopology.java
+++ b/samoa-threads/src/main/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopology.java
@@ -25,38 +25,42 @@
/**
* Topology for multithreaded engine.
+ *
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsTopology extends AbstractTopology {
- ThreadsTopology(String name) {
- super(name);
- }
+ ThreadsTopology(String name) {
+ super(name);
+ }
- public void run() {
- if (this.getEntranceProcessingItems() == null)
- throw new IllegalStateException("You need to set entrance PI before running the topology.");
- if (this.getEntranceProcessingItems().size() != 1)
- throw new IllegalStateException("ThreadsTopology supports 1 entrance PI only. Number of entrance PIs is "+this.getEntranceProcessingItems().size());
-
- this.setupProcessingItemInstances();
- ThreadsEntranceProcessingItem entrancePi = (ThreadsEntranceProcessingItem) this.getEntranceProcessingItems().toArray()[0];
- if (entrancePi == null)
- throw new IllegalStateException("You need to set entrance PI before running the topology.");
- entrancePi.getProcessor().onCreate(0); // id=0 as it is not used in simple mode
- entrancePi.startSendingEvents();
+ public void run() {
+ if (this.getEntranceProcessingItems() == null)
+ throw new IllegalStateException("You need to set entrance PI before running the topology.");
+ if (this.getEntranceProcessingItems().size() != 1)
+ throw new IllegalStateException("ThreadsTopology supports 1 entrance PI only. Number of entrance PIs is "
+ + this.getEntranceProcessingItems().size());
+
+ this.setupProcessingItemInstances();
+ ThreadsEntranceProcessingItem entrancePi = (ThreadsEntranceProcessingItem) this.getEntranceProcessingItems()
+ .toArray()[0];
+ if (entrancePi == null)
+ throw new IllegalStateException("You need to set entrance PI before running the topology.");
+ entrancePi.getProcessor().onCreate(0); // id=0 as it is not used in simple
+ // mode
+ entrancePi.startSendingEvents();
+ }
+
+ /*
+ * Tell all the ThreadsProcessingItems to create & init their replicas
+ * (ThreadsProcessingItemInstance)
+ */
+ private void setupProcessingItemInstances() {
+ for (IProcessingItem pi : this.getProcessingItems()) {
+ if (pi instanceof ThreadsProcessingItem) {
+ ThreadsProcessingItem tpi = (ThreadsProcessingItem) pi;
+ tpi.setupInstances();
+ }
}
-
- /*
- * Tell all the ThreadsProcessingItems to create & init their
- * replicas (ThreadsProcessingItemInstance)
- */
- private void setupProcessingItemInstances() {
- for (IProcessingItem pi:this.getProcessingItems()) {
- if (pi instanceof ThreadsProcessingItem) {
- ThreadsProcessingItem tpi = (ThreadsProcessingItem) pi;
- tpi.setupInstances();
- }
- }
- }
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/AlgosTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
index 5979d46..0b9b8a2 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/AlgosTest.java
@@ -24,45 +24,45 @@
public class AlgosTest {
- @Test(timeout = 60000)
- public void testVHTWithThreads() throws Exception {
+ @Test(timeout = 60000)
+ public void testVHTWithThreads() throws Exception {
- TestParams vhtConfig = new TestParams.Builder()
- .inputInstances(200_000)
- .samplingSize(20_000)
- .evaluationInstances(200_000)
- .classifiedInstances(200_000)
- .classificationsCorrect(55f)
- .kappaStat(-0.1f)
- .kappaTempStat(-0.1f)
- .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE + " -t 2")
- .resultFilePollTimeout(10)
- .prePollWait(10)
- .taskClassName(LocalThreadsDoTask.class.getName())
- .build();
- TestUtils.test(vhtConfig);
+ TestParams vhtConfig = new TestParams.Builder()
+ .inputInstances(200_000)
+ .samplingSize(20_000)
+ .evaluationInstances(200_000)
+ .classifiedInstances(200_000)
+ .classificationsCorrect(55f)
+ .kappaStat(-0.1f)
+ .kappaTempStat(-0.1f)
+ .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE + " -t 2")
+ .resultFilePollTimeout(10)
+ .prePollWait(10)
+ .taskClassName(LocalThreadsDoTask.class.getName())
+ .build();
+ TestUtils.test(vhtConfig);
- }
+ }
- @Test(timeout = 180000)
- public void testBaggingWithThreads() throws Exception {
- TestParams baggingConfig = new TestParams.Builder()
- .inputInstances(100_000)
- .samplingSize(10_000)
- .inputDelayMicroSec(100) // prevents saturating the system due to unbounded queues
- .evaluationInstances(90_000)
- .classifiedInstances(105_000)
- .classificationsCorrect(55f)
- .kappaStat(0f)
- .kappaTempStat(0f)
- .cliStringTemplate(TestParams.Templates.PREQEVAL_BAGGING_RANDOMTREE + " -t 2")
- .prePollWait(10)
- .resultFilePollTimeout(30)
- .taskClassName(LocalThreadsDoTask.class.getName())
- .build();
- TestUtils.test(baggingConfig);
+ @Test(timeout = 180000)
+ public void testBaggingWithThreads() throws Exception {
+ TestParams baggingConfig = new TestParams.Builder()
+ .inputInstances(100_000)
+ .samplingSize(10_000)
+ .inputDelayMicroSec(100) // prevents saturating the system due to
+ // unbounded queues
+ .evaluationInstances(90_000)
+ .classifiedInstances(105_000)
+ .classificationsCorrect(55f)
+ .kappaStat(0f)
+ .kappaTempStat(0f)
+ .cliStringTemplate(TestParams.Templates.PREQEVAL_BAGGING_RANDOMTREE + " -t 2")
+ .prePollWait(10)
+ .resultFilePollTimeout(30)
+ .taskClassName(LocalThreadsDoTask.class.getName())
+ .build();
+ TestUtils.test(baggingConfig);
- }
-
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactoryTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactoryTest.java
index eee8639..12b5b90 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactoryTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsComponentFactoryTest.java
@@ -37,78 +37,81 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsComponentFactoryTest {
- @Tested private ThreadsComponentFactory factory;
- @Mocked private Processor processor, processorReplica;
- @Mocked private EntranceProcessor entranceProcessor;
-
- private final int parallelism = 3;
- private final String topoName = "TestTopology";
-
+ @Tested
+ private ThreadsComponentFactory factory;
+ @Mocked
+ private Processor processor, processorReplica;
+ @Mocked
+ private EntranceProcessor entranceProcessor;
- @Before
- public void setUp() throws Exception {
- factory = new ThreadsComponentFactory();
- }
+ private final int parallelism = 3;
+ private final String topoName = "TestTopology";
- @Test
- public void testCreatePiNoParallelism() {
- new NonStrictExpectations() {
- {
- processor.newProcessor(processor);
- result=processorReplica;
- }
- };
- ProcessingItem pi = factory.createPi(processor);
- assertNotNull("ProcessingItem created is null.",pi);
- assertEquals("ProcessingItem created is not a ThreadsProcessingItem.",ThreadsProcessingItem.class,pi.getClass());
- assertEquals("Parallelism of PI is not 1",1,pi.getParallelism(),0);
- }
-
- @Test
- public void testCreatePiWithParallelism() {
- new NonStrictExpectations() {
- {
- processor.newProcessor(processor);
- result=processorReplica;
- }
- };
- ProcessingItem pi = factory.createPi(processor,parallelism);
- assertNotNull("ProcessingItem created is null.",pi);
- assertEquals("ProcessingItem created is not a ThreadsProcessingItem.",ThreadsProcessingItem.class,pi.getClass());
- assertEquals("Parallelism of PI is not ",parallelism,pi.getParallelism(),0);
- }
-
- @Test
- public void testCreateStream() {
- new NonStrictExpectations() {
- {
- processor.newProcessor(processor);
- result=processorReplica;
- }
- };
- ProcessingItem pi = factory.createPi(processor);
-
- Stream stream = factory.createStream(pi);
- assertNotNull("Stream created is null",stream);
- assertEquals("Stream created is not a ThreadsStream.",ThreadsStream.class,stream.getClass());
- }
-
- @Test
- public void testCreateTopology() {
- Topology topology = factory.createTopology(topoName);
- assertNotNull("Topology created is null.",topology);
- assertEquals("Topology created is not a ThreadsTopology.",ThreadsTopology.class,topology.getClass());
- }
-
- @Test
- public void testCreateEntrancePi() {
- EntranceProcessingItem entrancePi = factory.createEntrancePi(entranceProcessor);
- assertNotNull("EntranceProcessingItem created is null.",entrancePi);
- assertEquals("EntranceProcessingItem created is not a ThreadsEntranceProcessingItem.",ThreadsEntranceProcessingItem.class,entrancePi.getClass());
- assertSame("EntranceProcessor is not set correctly.",entranceProcessor, entrancePi.getProcessor());
- }
+ @Before
+ public void setUp() throws Exception {
+ factory = new ThreadsComponentFactory();
+ }
+
+ @Test
+ public void testCreatePiNoParallelism() {
+ new NonStrictExpectations() {
+ {
+ processor.newProcessor(processor);
+ result = processorReplica;
+ }
+ };
+ ProcessingItem pi = factory.createPi(processor);
+ assertNotNull("ProcessingItem created is null.", pi);
+ assertEquals("ProcessingItem created is not a ThreadsProcessingItem.", ThreadsProcessingItem.class, pi.getClass());
+ assertEquals("Parallelism of PI is not 1", 1, pi.getParallelism(), 0);
+ }
+
+ @Test
+ public void testCreatePiWithParallelism() {
+ new NonStrictExpectations() {
+ {
+ processor.newProcessor(processor);
+ result = processorReplica;
+ }
+ };
+ ProcessingItem pi = factory.createPi(processor, parallelism);
+ assertNotNull("ProcessingItem created is null.", pi);
+ assertEquals("ProcessingItem created is not a ThreadsProcessingItem.", ThreadsProcessingItem.class, pi.getClass());
+ assertEquals("Parallelism of PI is not ", parallelism, pi.getParallelism(), 0);
+ }
+
+ @Test
+ public void testCreateStream() {
+ new NonStrictExpectations() {
+ {
+ processor.newProcessor(processor);
+ result = processorReplica;
+ }
+ };
+ ProcessingItem pi = factory.createPi(processor);
+
+ Stream stream = factory.createStream(pi);
+ assertNotNull("Stream created is null", stream);
+ assertEquals("Stream created is not a ThreadsStream.", ThreadsStream.class, stream.getClass());
+ }
+
+ @Test
+ public void testCreateTopology() {
+ Topology topology = factory.createTopology(topoName);
+ assertNotNull("Topology created is null.", topology);
+ assertEquals("Topology created is not a ThreadsTopology.", ThreadsTopology.class, topology.getClass());
+ }
+
+ @Test
+ public void testCreateEntrancePi() {
+ EntranceProcessingItem entrancePi = factory.createEntrancePi(entranceProcessor);
+ assertNotNull("EntranceProcessingItem created is null.", entrancePi);
+ assertEquals("EntranceProcessingItem created is not a ThreadsEntranceProcessingItem.",
+ ThreadsEntranceProcessingItem.class, entrancePi.getClass());
+ assertSame("EntranceProcessor is not set correctly.", entranceProcessor, entrancePi.getProcessor());
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngineTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngineTest.java
index cdb8949..c8a3c3d 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngineTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEngineTest.java
@@ -29,101 +29,105 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsEngineTest {
- @Mocked ThreadsTopology topology;
-
- private final int numThreads = 4;
- private final int numThreadsSmaller = 3;
- private final int numThreadsLarger = 5;
+ @Mocked
+ ThreadsTopology topology;
- @After
- public void cleanup() {
- ThreadsEngine.clearThreadPool();
- }
-
- @Test
- public void testSetNumberOfThreadsSimple() {
- ThreadsEngine.setNumberOfThreads(numThreads);
- assertEquals("Number of threads is not set correctly.", numThreads,
- ThreadsEngine.getNumberOfThreads(),0);
- }
-
- @Test
- public void testSetNumberOfThreadsRepeat() {
- ThreadsEngine.setNumberOfThreads(numThreads);
- ThreadsEngine.setNumberOfThreads(numThreads);
- assertEquals("Number of threads is not set correctly.", numThreads,
- ThreadsEngine.getNumberOfThreads(),0);
- }
-
- @Test
- public void testSetNumberOfThreadsIncrease() {
- ThreadsEngine.setNumberOfThreads(numThreads);
- ThreadsEngine.setNumberOfThreads(numThreadsLarger);
- assertEquals("Number of threads is not set correctly.", numThreadsLarger,
- ThreadsEngine.getNumberOfThreads(),0);
- }
-
- @Test(expected=IllegalStateException.class)
- public void testSetNumberOfThreadsDecrease() {
- ThreadsEngine.setNumberOfThreads(numThreads);
- ThreadsEngine.setNumberOfThreads(numThreadsSmaller);
- // Exception expected
- }
-
- @Test(expected=IllegalStateException.class)
- public void testSetNumberOfThreadsNegative() {
- ThreadsEngine.setNumberOfThreads(-1);
- // Exception expected
- }
-
- @Test(expected=IllegalStateException.class)
- public void testSetNumberOfThreadsZero() {
- ThreadsEngine.setNumberOfThreads(0);
- // Exception expected
- }
-
- @Test
- public void testClearThreadPool() {
- ThreadsEngine.setNumberOfThreads(numThreads);
- ThreadsEngine.clearThreadPool();
- assertEquals("ThreadsEngine was not shutdown properly.", 0, ThreadsEngine.getNumberOfThreads());
- }
+ private final int numThreads = 4;
+ private final int numThreadsSmaller = 3;
+ private final int numThreadsLarger = 5;
- @Test
- public void testGetThreadWithIndexWithinPoolSize() {
- ThreadsEngine.setNumberOfThreads(numThreads);
- for (int i=0; i<numThreads; i++) {
- assertNotNull("ExecutorService is not initialized correctly.", ThreadsEngine.getThreadWithIndex(i));
- }
- }
-
- @Test
- public void testGetThreadWithIndexOutOfPoolSize() {
- ThreadsEngine.setNumberOfThreads(numThreads);
- for (int i=0; i<numThreads+3; i++) {
- assertNotNull("ExecutorService is not initialized correctly.", ThreadsEngine.getThreadWithIndex(i));
- }
- }
-
- @Test(expected=IllegalStateException.class)
- public void testGetThreadWithIndexFromEmptyPool() {
- for (int i=0; i<numThreads; i++) {
- ThreadsEngine.getThreadWithIndex(i);
- }
- }
+ @After
+ public void cleanup() {
+ ThreadsEngine.clearThreadPool();
+ }
- @Test
- public void testSubmitTopology() {
- ThreadsEngine.submitTopology(topology, numThreads);
- new Verifications() {{
- topology.run(); times=1;
- }};
- assertEquals("Number of threads is not set correctly.", numThreads,
- ThreadsEngine.getNumberOfThreads(),0);
- }
+ @Test
+ public void testSetNumberOfThreadsSimple() {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ assertEquals("Number of threads is not set correctly.", numThreads,
+ ThreadsEngine.getNumberOfThreads(), 0);
+ }
+
+ @Test
+ public void testSetNumberOfThreadsRepeat() {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ assertEquals("Number of threads is not set correctly.", numThreads,
+ ThreadsEngine.getNumberOfThreads(), 0);
+ }
+
+ @Test
+ public void testSetNumberOfThreadsIncrease() {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ ThreadsEngine.setNumberOfThreads(numThreadsLarger);
+ assertEquals("Number of threads is not set correctly.", numThreadsLarger,
+ ThreadsEngine.getNumberOfThreads(), 0);
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testSetNumberOfThreadsDecrease() {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ ThreadsEngine.setNumberOfThreads(numThreadsSmaller);
+ // Exception expected
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testSetNumberOfThreadsNegative() {
+ ThreadsEngine.setNumberOfThreads(-1);
+ // Exception expected
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testSetNumberOfThreadsZero() {
+ ThreadsEngine.setNumberOfThreads(0);
+ // Exception expected
+ }
+
+ @Test
+ public void testClearThreadPool() {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ ThreadsEngine.clearThreadPool();
+ assertEquals("ThreadsEngine was not shutdown properly.", 0, ThreadsEngine.getNumberOfThreads());
+ }
+
+ @Test
+ public void testGetThreadWithIndexWithinPoolSize() {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ for (int i = 0; i < numThreads; i++) {
+ assertNotNull("ExecutorService is not initialized correctly.", ThreadsEngine.getThreadWithIndex(i));
+ }
+ }
+
+ @Test
+ public void testGetThreadWithIndexOutOfPoolSize() {
+ ThreadsEngine.setNumberOfThreads(numThreads);
+ for (int i = 0; i < numThreads + 3; i++) {
+ assertNotNull("ExecutorService is not initialized correctly.", ThreadsEngine.getThreadWithIndex(i));
+ }
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testGetThreadWithIndexFromEmptyPool() {
+ for (int i = 0; i < numThreads; i++) {
+ ThreadsEngine.getThreadWithIndex(i);
+ }
+ }
+
+ @Test
+ public void testSubmitTopology() {
+ ThreadsEngine.submitTopology(topology, numThreads);
+ new Verifications() {
+ {
+ topology.run();
+ times = 1;
+ }
+ };
+ assertEquals("Number of threads is not set correctly.", numThreads,
+ ThreadsEngine.getNumberOfThreads(), 0);
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItemTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItemTest.java
index 2dab489..db2a3fb 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItemTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEntranceProcessingItemTest.java
@@ -34,100 +34,118 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsEntranceProcessingItemTest {
- @Tested private ThreadsEntranceProcessingItem entrancePi;
-
- @Mocked private EntranceProcessor entranceProcessor;
- @Mocked private Stream outputStream, anotherStream;
- @Mocked private ContentEvent event;
-
- @Mocked private Thread unused;
-
- /**
- * @throws java.lang.Exception
- */
- @Before
- public void setUp() throws Exception {
- entrancePi = new ThreadsEntranceProcessingItem(entranceProcessor);
- }
+ @Tested
+ private ThreadsEntranceProcessingItem entrancePi;
- @Test
- public void testContructor() {
- assertSame("EntranceProcessor is not set correctly.",entranceProcessor,entrancePi.getProcessor());
- }
-
- @Test
- public void testSetOutputStream() {
- entrancePi.setOutputStream(outputStream);
- assertSame("OutoutStream is not set correctly.",outputStream,entrancePi.getOutputStream());
- }
-
- @Test
- public void testSetOutputStreamRepeate() {
- entrancePi.setOutputStream(outputStream);
- entrancePi.setOutputStream(outputStream);
- assertSame("OutoutStream is not set correctly.",outputStream,entrancePi.getOutputStream());
- }
-
- @Test(expected=IllegalStateException.class)
- public void testSetOutputStreamError() {
- entrancePi.setOutputStream(outputStream);
- entrancePi.setOutputStream(anotherStream);
- }
-
- @Test
- public void testStartSendingEvents() {
- entrancePi.setOutputStream(outputStream);
- new StrictExpectations() {
- {
- for (int i=0; i<1; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=false;
- }
-
- for (int i=0; i<5; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=true;
- entranceProcessor.nextEvent(); result=event;
- outputStream.put(event);
- }
-
- for (int i=0; i<2; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=false;
- }
-
- for (int i=0; i<5; i++) {
- entranceProcessor.isFinished(); result=false;
- entranceProcessor.hasNext(); result=true;
- entranceProcessor.nextEvent(); result=event;
- outputStream.put(event);
- }
+ @Mocked
+ private EntranceProcessor entranceProcessor;
+ @Mocked
+ private Stream outputStream, anotherStream;
+ @Mocked
+ private ContentEvent event;
- entranceProcessor.isFinished(); result=true; times=1;
- entranceProcessor.hasNext(); times=0;
+ @Mocked
+ private Thread unused;
+ /**
+ * @throws java.lang.Exception
+ */
+ @Before
+ public void setUp() throws Exception {
+ entrancePi = new ThreadsEntranceProcessingItem(entranceProcessor);
+ }
- }
- };
- entrancePi.startSendingEvents();
- new Verifications() {
- {
- try {
- Thread.sleep(anyInt); times=3;
- } catch (InterruptedException e) {
-
- }
- }
- };
- }
-
- @Test(expected=IllegalStateException.class)
- public void testStartSendingEventsError() {
- entrancePi.startSendingEvents();
- }
+ @Test
+  public void testConstructor() {
+ assertSame("EntranceProcessor is not set correctly.", entranceProcessor, entrancePi.getProcessor());
+ }
+
+ @Test
+ public void testSetOutputStream() {
+ entrancePi.setOutputStream(outputStream);
+    assertSame("OutputStream is not set correctly.", outputStream, entrancePi.getOutputStream());
+ }
+
+ @Test
+  public void testSetOutputStreamRepeated() {
+ entrancePi.setOutputStream(outputStream);
+ entrancePi.setOutputStream(outputStream);
+    assertSame("OutputStream is not set correctly.", outputStream, entrancePi.getOutputStream());
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testSetOutputStreamError() {
+ entrancePi.setOutputStream(outputStream);
+ entrancePi.setOutputStream(anotherStream);
+ }
+
+ @Test
+ public void testStartSendingEvents() {
+ entrancePi.setOutputStream(outputStream);
+ new StrictExpectations() {
+ {
+ for (int i = 0; i < 1; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = false;
+ }
+
+ for (int i = 0; i < 5; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = true;
+ entranceProcessor.nextEvent();
+ result = event;
+ outputStream.put(event);
+ }
+
+ for (int i = 0; i < 2; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = false;
+ }
+
+ for (int i = 0; i < 5; i++) {
+ entranceProcessor.isFinished();
+ result = false;
+ entranceProcessor.hasNext();
+ result = true;
+ entranceProcessor.nextEvent();
+ result = event;
+ outputStream.put(event);
+ }
+
+ entranceProcessor.isFinished();
+ result = true;
+ times = 1;
+ entranceProcessor.hasNext();
+ times = 0;
+
+ }
+ };
+ entrancePi.startSendingEvents();
+ new Verifications() {
+ {
+ try {
+ Thread.sleep(anyInt);
+ times = 3;
+ } catch (InterruptedException e) {
+
+ }
+ }
+ };
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testStartSendingEventsError() {
+ entrancePi.startSendingEvents();
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnableTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnableTest.java
index f744162..1e70d10 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnableTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsEventRunnableTest.java
@@ -31,37 +31,41 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsEventRunnableTest {
- @Tested private ThreadsEventRunnable task;
-
- @Mocked private ThreadsProcessingItemInstance piInstance;
- @Mocked private ContentEvent event;
-
- /**
- * @throws java.lang.Exception
- */
- @Before
- public void setUp() throws Exception {
- task = new ThreadsEventRunnable(piInstance, event);
- }
+ @Tested
+ private ThreadsEventRunnable task;
- @Test
- public void testConstructor() {
- assertSame("WorkerProcessingItem is not set correctly.",piInstance,task.getWorkerProcessingItem());
- assertSame("ContentEvent is not set correctly.",event,task.getContentEvent());
- }
-
- @Test
- public void testRun() {
- task.run();
- new Verifications () {
- {
- piInstance.processEvent(event); times=1;
- }
- };
- }
+ @Mocked
+ private ThreadsProcessingItemInstance piInstance;
+ @Mocked
+ private ContentEvent event;
+
+ /**
+ * @throws java.lang.Exception
+ */
+ @Before
+ public void setUp() throws Exception {
+ task = new ThreadsEventRunnable(piInstance, event);
+ }
+
+ @Test
+ public void testConstructor() {
+ assertSame("WorkerProcessingItem is not set correctly.", piInstance, task.getWorkerProcessingItem());
+ assertSame("ContentEvent is not set correctly.", event, task.getContentEvent());
+ }
+
+ @Test
+ public void testRun() {
+ task.run();
+ new Verifications() {
+ {
+ piInstance.processEvent(event);
+ times = 1;
+ }
+ };
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstanceTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstanceTest.java
index 33af044..d4f78b0 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstanceTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemInstanceTest.java
@@ -32,36 +32,40 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsProcessingItemInstanceTest {
- @Tested private ThreadsProcessingItemInstance piInstance;
-
- @Mocked private Processor processor;
- @Mocked private ContentEvent event;
-
- private final int threadIndex = 2;
-
- @Before
- public void setUp() throws Exception {
- piInstance = new ThreadsProcessingItemInstance(processor, threadIndex);
- }
+ @Tested
+ private ThreadsProcessingItemInstance piInstance;
- @Test
- public void testConstructor() {
- assertSame("Processor is not set correctly.", processor, piInstance.getProcessor());
- assertEquals("Thread index is not set correctly.", threadIndex, piInstance.getThreadIndex(),0);
- }
-
- @Test
- public void testProcessEvent() {
- piInstance.processEvent(event);
- new Verifications() {
- {
- processor.process(event); times=1;
- }
- };
- }
+ @Mocked
+ private Processor processor;
+ @Mocked
+ private ContentEvent event;
+
+ private final int threadIndex = 2;
+
+ @Before
+ public void setUp() throws Exception {
+ piInstance = new ThreadsProcessingItemInstance(processor, threadIndex);
+ }
+
+ @Test
+ public void testConstructor() {
+ assertSame("Processor is not set correctly.", processor, piInstance.getProcessor());
+ assertEquals("Thread index is not set correctly.", threadIndex, piInstance.getThreadIndex(), 0);
+ }
+
+ @Test
+ public void testProcessEvent() {
+ piInstance.processEvent(event);
+ new Verifications() {
+ {
+ processor.process(event);
+ times = 1;
+ }
+ };
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemTest.java
index ad7cd56..d148e8e 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsProcessingItemTest.java
@@ -39,135 +39,142 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsProcessingItemTest {
- @Tested private ThreadsProcessingItem pi;
-
- @Mocked private ThreadsEngine unused;
- @Mocked private ExecutorService threadPool;
- @Mocked private ThreadsEventRunnable task;
-
- @Mocked private Processor processor, processorReplica;
- @Mocked private ThreadsStream stream;
- @Mocked private StreamDestination destination;
- @Mocked private ContentEvent event;
-
- private final int parallelism = 4;
- private final int counter = 2;
-
- private ThreadsProcessingItemInstance instance;
-
-
- @Before
- public void setUp() throws Exception {
- new NonStrictExpectations() {
- {
- processor.newProcessor(processor);
- result=processorReplica;
- }
- };
- pi = new ThreadsProcessingItem(processor, parallelism);
- }
+ @Tested
+ private ThreadsProcessingItem pi;
- @Test
- public void testConstructor() {
- assertSame("Processor was not set correctly.",processor,pi.getProcessor());
- assertEquals("Parallelism was not set correctly.",parallelism,pi.getParallelism(),0);
- }
-
- @Test
- public void testConnectInputShuffleStream() {
- new Expectations() {
- {
- destination = new StreamDestination(pi, parallelism, PartitioningScheme.SHUFFLE);
- stream.addDestination(destination);
- }
- };
- pi.connectInputShuffleStream(stream);
- }
-
- @Test
- public void testConnectInputKeyStream() {
- new Expectations() {
- {
- destination = new StreamDestination(pi, parallelism, PartitioningScheme.GROUP_BY_KEY);
- stream.addDestination(destination);
- }
- };
- pi.connectInputKeyStream(stream);
- }
-
- @Test
- public void testConnectInputAllStream() {
- new Expectations() {
- {
- destination = new StreamDestination(pi, parallelism, PartitioningScheme.BROADCAST);
- stream.addDestination(destination);
- }
- };
- pi.connectInputAllStream(stream);
- }
-
- @Test
- public void testSetupInstances() {
- new Expectations() {
- {
- for (int i=0; i<parallelism; i++) {
- processor.newProcessor(processor);
- result=processor;
-
- processor.onCreate(anyInt);
- }
- }
- };
- pi.setupInstances();
- List<ThreadsProcessingItemInstance> instances = pi.getProcessingItemInstances();
- assertNotNull("List of PI instances is null.",instances);
- assertEquals("Number of instances does not match parallelism.",parallelism,instances.size(),0);
- for(int i=0; i<instances.size();i++) {
- assertNotNull("Instance "+i+" is null.",instances.get(i));
- assertEquals("Instance "+i+" is not a ThreadsWorkerProcessingItem.",ThreadsProcessingItemInstance.class,instances.get(i).getClass());
- }
- }
-
- @Test(expected=IllegalStateException.class)
- public void testProcessEventError() {
- pi.processEvent(event, counter);
- }
-
- @Test
- public void testProcessEvent() {
- new Expectations() {
- {
- for (int i=0; i<parallelism; i++) {
- processor.newProcessor(processor);
- result=processor;
-
- processor.onCreate(anyInt);
- }
- }
- };
- pi.setupInstances();
-
- instance = pi.getProcessingItemInstances().get(counter);
- new NonStrictExpectations() {
- {
- ThreadsEngine.getThreadWithIndex(anyInt);
- result=threadPool;
-
-
- }
- };
- new Expectations() {
- {
- task = new ThreadsEventRunnable(instance, event);
- threadPool.submit(task);
- }
- };
- pi.processEvent(event, counter);
-
- }
+ @Mocked
+ private ThreadsEngine unused;
+ @Mocked
+ private ExecutorService threadPool;
+ @Mocked
+ private ThreadsEventRunnable task;
+
+ @Mocked
+ private Processor processor, processorReplica;
+ @Mocked
+ private ThreadsStream stream;
+ @Mocked
+ private StreamDestination destination;
+ @Mocked
+ private ContentEvent event;
+
+ private final int parallelism = 4;
+ private final int counter = 2;
+
+ private ThreadsProcessingItemInstance instance;
+
+ @Before
+ public void setUp() throws Exception {
+ new NonStrictExpectations() {
+ {
+ processor.newProcessor(processor);
+ result = processorReplica;
+ }
+ };
+ pi = new ThreadsProcessingItem(processor, parallelism);
+ }
+
+ @Test
+ public void testConstructor() {
+ assertSame("Processor was not set correctly.", processor, pi.getProcessor());
+ assertEquals("Parallelism was not set correctly.", parallelism, pi.getParallelism(), 0);
+ }
+
+ @Test
+ public void testConnectInputShuffleStream() {
+ new Expectations() {
+ {
+ destination = new StreamDestination(pi, parallelism, PartitioningScheme.SHUFFLE);
+ stream.addDestination(destination);
+ }
+ };
+ pi.connectInputShuffleStream(stream);
+ }
+
+ @Test
+ public void testConnectInputKeyStream() {
+ new Expectations() {
+ {
+ destination = new StreamDestination(pi, parallelism, PartitioningScheme.GROUP_BY_KEY);
+ stream.addDestination(destination);
+ }
+ };
+ pi.connectInputKeyStream(stream);
+ }
+
+ @Test
+ public void testConnectInputAllStream() {
+ new Expectations() {
+ {
+ destination = new StreamDestination(pi, parallelism, PartitioningScheme.BROADCAST);
+ stream.addDestination(destination);
+ }
+ };
+ pi.connectInputAllStream(stream);
+ }
+
+ @Test
+ public void testSetupInstances() {
+ new Expectations() {
+ {
+ for (int i = 0; i < parallelism; i++) {
+ processor.newProcessor(processor);
+ result = processor;
+
+ processor.onCreate(anyInt);
+ }
+ }
+ };
+ pi.setupInstances();
+ List<ThreadsProcessingItemInstance> instances = pi.getProcessingItemInstances();
+ assertNotNull("List of PI instances is null.", instances);
+ assertEquals("Number of instances does not match parallelism.", parallelism, instances.size(), 0);
+ for (int i = 0; i < instances.size(); i++) {
+ assertNotNull("Instance " + i + " is null.", instances.get(i));
+ assertEquals("Instance " + i + " is not a ThreadsWorkerProcessingItem.", ThreadsProcessingItemInstance.class,
+ instances.get(i).getClass());
+ }
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testProcessEventError() {
+ pi.processEvent(event, counter);
+ }
+
+ @Test
+ public void testProcessEvent() {
+ new Expectations() {
+ {
+ for (int i = 0; i < parallelism; i++) {
+ processor.newProcessor(processor);
+ result = processor;
+
+ processor.onCreate(anyInt);
+ }
+ }
+ };
+ pi.setupInstances();
+
+ instance = pi.getProcessingItemInstances().get(counter);
+ new NonStrictExpectations() {
+ {
+ ThreadsEngine.getThreadWithIndex(anyInt);
+ result = threadPool;
+
+ }
+ };
+ new Expectations() {
+ {
+ task = new ThreadsEventRunnable(instance, event);
+ threadPool.submit(task);
+ }
+ };
+ pi.processEvent(event, counter);
+
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsStreamTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsStreamTest.java
index 27d2acd..abe57ce 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsStreamTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsStreamTest.java
@@ -41,87 +41,95 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
@RunWith(Parameterized.class)
public class ThreadsStreamTest {
-
- @Tested private ThreadsStream stream;
-
- @Mocked private ThreadsProcessingItem sourcePi, destPi;
- @Mocked private ContentEvent event;
- @Mocked private StreamDestination destination;
- private final String eventKey = "eventkey";
- private final int parallelism;
- private final PartitioningScheme scheme;
-
-
- @Parameters
- public static Collection<Object[]> generateParameters() {
- return Arrays.asList(new Object[][] {
- { 2, PartitioningScheme.SHUFFLE },
- { 3, PartitioningScheme.GROUP_BY_KEY },
- { 4, PartitioningScheme.BROADCAST }
- });
- }
-
- public ThreadsStreamTest(int parallelism, PartitioningScheme scheme) {
- this.parallelism = parallelism;
- this.scheme = scheme;
- }
-
- @Before
- public void setUp() throws Exception {
- stream = new ThreadsStream(sourcePi);
- stream.addDestination(destination);
- }
-
- @Test
- public void testAddDestination() {
- boolean found = false;
- for (StreamDestination sd:stream.getDestinations()) {
- if (sd == destination) {
- found = true;
- break;
- }
- }
- assertTrue("Destination object was not added in stream's destinations set.",found);
- }
+ @Tested
+ private ThreadsStream stream;
- @Test
- public void testPut() {
- new NonStrictExpectations() {
- {
- event.getKey(); result=eventKey;
- destination.getProcessingItem(); result=destPi;
- destination.getPartitioningScheme(); result=scheme;
- destination.getParallelism(); result=parallelism;
-
- }
- };
- switch(this.scheme) {
- case SHUFFLE: case GROUP_BY_KEY:
- new Expectations() {
- {
-
- // TODO: restrict the range of counter value
- destPi.processEvent(event, anyInt); times=1;
- }
- };
- break;
- case BROADCAST:
- new Expectations() {
- {
- // TODO: restrict the range of counter value
- destPi.processEvent(event, anyInt); times=parallelism;
- }
- };
- break;
- }
- stream.put(event);
- }
-
-
+ @Mocked
+ private ThreadsProcessingItem sourcePi, destPi;
+ @Mocked
+ private ContentEvent event;
+ @Mocked
+ private StreamDestination destination;
+
+ private final String eventKey = "eventkey";
+ private final int parallelism;
+ private final PartitioningScheme scheme;
+
+ @Parameters
+ public static Collection<Object[]> generateParameters() {
+ return Arrays.asList(new Object[][] {
+ { 2, PartitioningScheme.SHUFFLE },
+ { 3, PartitioningScheme.GROUP_BY_KEY },
+ { 4, PartitioningScheme.BROADCAST }
+ });
+ }
+
+ public ThreadsStreamTest(int parallelism, PartitioningScheme scheme) {
+ this.parallelism = parallelism;
+ this.scheme = scheme;
+ }
+
+ @Before
+ public void setUp() throws Exception {
+ stream = new ThreadsStream(sourcePi);
+ stream.addDestination(destination);
+ }
+
+ @Test
+ public void testAddDestination() {
+ boolean found = false;
+ for (StreamDestination sd : stream.getDestinations()) {
+ if (sd == destination) {
+ found = true;
+ break;
+ }
+ }
+ assertTrue("Destination object was not added in stream's destinations set.", found);
+ }
+
+ @Test
+ public void testPut() {
+ new NonStrictExpectations() {
+ {
+ event.getKey();
+ result = eventKey;
+ destination.getProcessingItem();
+ result = destPi;
+ destination.getPartitioningScheme();
+ result = scheme;
+ destination.getParallelism();
+ result = parallelism;
+
+ }
+ };
+ switch (this.scheme) {
+ case SHUFFLE:
+ case GROUP_BY_KEY:
+ new Expectations() {
+ {
+
+ // TODO: restrict the range of counter value
+ destPi.processEvent(event, anyInt);
+ times = 1;
+ }
+ };
+ break;
+ case BROADCAST:
+ new Expectations() {
+ {
+ // TODO: restrict the range of counter value
+ destPi.processEvent(event, anyInt);
+ times = parallelism;
+ }
+ };
+ break;
+ }
+ stream.put(event);
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopologyTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopologyTest.java
index 46847f5..6891a63 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopologyTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/topology/impl/ThreadsTopologyTest.java
@@ -35,50 +35,53 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
public class ThreadsTopologyTest {
- @Tested private ThreadsTopology topology;
-
- @Mocked private ThreadsEntranceProcessingItem entrancePi;
- @Mocked private EntranceProcessor entranceProcessor;
-
- @Before
- public void setUp() throws Exception {
- topology = new ThreadsTopology("TestTopology");
- }
+ @Tested
+ private ThreadsTopology topology;
- @Test
- public void testAddEntrancePi() {
- topology.addEntranceProcessingItem(entrancePi);
- Set<EntranceProcessingItem> entrancePIs = topology.getEntranceProcessingItems();
- assertNotNull("Set of entrance PIs is null.",entrancePIs);
- assertEquals("Number of entrance PI in ThreadsTopology must be 1",1,entrancePIs.size());
- assertSame("Entrance PI was not set correctly.",entrancePi,entrancePIs.toArray()[0]);
- // TODO: verify that entrance PI is in the set of ProcessingItems
- // Need to access topology's set of PIs (getProcessingItems() method)
- }
-
- @Test
- public void testRun() {
- topology.addEntranceProcessingItem(entrancePi);
-
- new Expectations() {
- {
- entrancePi.getProcessor();
- result=entranceProcessor;
- entranceProcessor.onCreate(anyInt);
-
- entrancePi.startSendingEvents();
- }
- };
- topology.run();
- }
-
- @Test(expected=IllegalStateException.class)
- public void testRunWithoutEntrancePI() {
- topology.run();
- }
+ @Mocked
+ private ThreadsEntranceProcessingItem entrancePi;
+ @Mocked
+ private EntranceProcessor entranceProcessor;
+
+ @Before
+ public void setUp() throws Exception {
+ topology = new ThreadsTopology("TestTopology");
+ }
+
+ @Test
+ public void testAddEntrancePi() {
+ topology.addEntranceProcessingItem(entrancePi);
+ Set<EntranceProcessingItem> entrancePIs = topology.getEntranceProcessingItems();
+ assertNotNull("Set of entrance PIs is null.", entrancePIs);
+ assertEquals("Number of entrance PI in ThreadsTopology must be 1", 1, entrancePIs.size());
+ assertSame("Entrance PI was not set correctly.", entrancePi, entrancePIs.toArray()[0]);
+ // TODO: verify that entrance PI is in the set of ProcessingItems
+ // Need to access topology's set of PIs (getProcessingItems() method)
+ }
+
+ @Test
+ public void testRun() {
+ topology.addEntranceProcessingItem(entrancePi);
+
+ new Expectations() {
+ {
+ entrancePi.getProcessor();
+ result = entranceProcessor;
+ entranceProcessor.onCreate(anyInt);
+
+ entrancePi.startSendingEvents();
+ }
+ };
+ topology.run();
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testRunWithoutEntrancePI() {
+ topology.run();
+ }
}
diff --git a/samoa-threads/src/test/java/com/yahoo/labs/samoa/utils/StreamDestinationTest.java b/samoa-threads/src/test/java/com/yahoo/labs/samoa/utils/StreamDestinationTest.java
index 19c5421..c165b3e 100644
--- a/samoa-threads/src/test/java/com/yahoo/labs/samoa/utils/StreamDestinationTest.java
+++ b/samoa-threads/src/test/java/com/yahoo/labs/samoa/utils/StreamDestinationTest.java
@@ -40,41 +40,43 @@
/**
* @author Anh Thu Vu
- *
+ *
*/
@RunWith(Parameterized.class)
public class StreamDestinationTest {
- @Tested private StreamDestination destination;
-
- @Mocked private IProcessingItem pi;
- private final int parallelism;
- private final PartitioningScheme scheme;
-
- @Parameters
- public static Collection<Object[]> generateParameters() {
- return Arrays.asList(new Object[][] {
- { 3, PartitioningScheme.SHUFFLE },
- { 2, PartitioningScheme.GROUP_BY_KEY },
- { 5, PartitioningScheme.BROADCAST }
- });
- }
-
- public StreamDestinationTest(int parallelism, PartitioningScheme scheme) {
- this.parallelism = parallelism;
- this.scheme = scheme;
- }
-
- @Before
- public void setUp() throws Exception {
- destination = new StreamDestination(pi, parallelism, scheme);
- }
+ @Tested
+ private StreamDestination destination;
- @Test
- public void testContructor() {
- assertSame("The IProcessingItem is not set correctly.", pi, destination.getProcessingItem());
- assertEquals("Parallelism value is not set correctly.", parallelism, destination.getParallelism(), 0);
- assertEquals("EventAllocationType is not set correctly.", scheme, destination.getPartitioningScheme());
- }
+ @Mocked
+ private IProcessingItem pi;
+ private final int parallelism;
+ private final PartitioningScheme scheme;
+
+ @Parameters
+ public static Collection<Object[]> generateParameters() {
+ return Arrays.asList(new Object[][] {
+ { 3, PartitioningScheme.SHUFFLE },
+ { 2, PartitioningScheme.GROUP_BY_KEY },
+ { 5, PartitioningScheme.BROADCAST }
+ });
+ }
+
+ public StreamDestinationTest(int parallelism, PartitioningScheme scheme) {
+ this.parallelism = parallelism;
+ this.scheme = scheme;
+ }
+
+ @Before
+ public void setUp() throws Exception {
+ destination = new StreamDestination(pi, parallelism, scheme);
+ }
+
+ @Test
+  public void testConstructor() {
+ assertSame("The IProcessingItem is not set correctly.", pi, destination.getProcessingItem());
+ assertEquals("Parallelism value is not set correctly.", parallelism, destination.getParallelism(), 0);
+ assertEquals("EventAllocationType is not set correctly.", scheme, destination.getPartitioningScheme());
+ }
}