SAMOA-58: Pulled the class index command-line parameter up from AvroFileStream and ArffFileStream into their common parent, FileStream.
Added a class weights command-line parameter to FileStream.
diff --git a/samoa-api/src/main/java/org/apache/samoa/streams/ArffFileStream.java b/samoa-api/src/main/java/org/apache/samoa/streams/ArffFileStream.java
index 417eb2e..070021e 100644
--- a/samoa-api/src/main/java/org/apache/samoa/streams/ArffFileStream.java
+++ b/samoa-api/src/main/java/org/apache/samoa/streams/ArffFileStream.java
@@ -41,9 +41,9 @@
public FileOption arffFileOption = new FileOption("arffFile", 'f',
"ARFF File(s) to load.", null, null, false);
- public IntOption classIndexOption = new IntOption("classIndex", 'c',
+ /*public IntOption classIndexOption = new IntOption("classIndex", 'c',
"Class index of data. 0 for none or -1 for last attribute in file.",
- -1, -1, Integer.MAX_VALUE);
+ -1, -1, Integer.MAX_VALUE);*/
protected InstanceExample lastInstanceRead;
private BufferedReader fileReader;
diff --git a/samoa-api/src/main/java/org/apache/samoa/streams/AvroFileStream.java b/samoa-api/src/main/java/org/apache/samoa/streams/AvroFileStream.java
index 59bf22b..7c575d0 100644
--- a/samoa-api/src/main/java/org/apache/samoa/streams/AvroFileStream.java
+++ b/samoa-api/src/main/java/org/apache/samoa/streams/AvroFileStream.java
@@ -45,8 +45,8 @@
private static final Logger logger = LoggerFactory.getLogger(AvroFileStream.class);
public FileOption avroFileOption = new FileOption("avroFile", 'f', "Avro File(s) to load.", null, null, false);
- public IntOption classIndexOption = new IntOption("classIndex", 'c',
- "Class index of data. 0 for none or -1 for last attribute in file.", -1, -1, Integer.MAX_VALUE);
+ /*public IntOption classIndexOption = new IntOption("classIndex", 'c',
+ "Class index of data. 0 for none or -1 for last attribute in file.", -1, -1, Integer.MAX_VALUE);*/
public StringOption encodingFormatOption = new StringOption("encodingFormatOption", 'e',
"Encoding format for Avro Files. Can be JSON/AVRO", "BINARY");
diff --git a/samoa-api/src/main/java/org/apache/samoa/streams/FileStream.java b/samoa-api/src/main/java/org/apache/samoa/streams/FileStream.java
index cfa8de5..d9a7554 100644
--- a/samoa-api/src/main/java/org/apache/samoa/streams/FileStream.java
+++ b/samoa-api/src/main/java/org/apache/samoa/streams/FileStream.java
@@ -20,12 +20,7 @@
* #L%
*/
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.io.Reader;
-
+import com.github.javacliparser.*;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.instances.InstancesHeader;
import org.apache.samoa.moa.core.InstanceExample;
@@ -34,7 +29,8 @@
import org.apache.samoa.moa.tasks.TaskMonitor;
import org.apache.samoa.streams.fs.FileStreamSource;
-import com.github.javacliparser.ClassOption;
+import java.io.IOException;
+import java.io.InputStream;
/**
* InstanceStream for files (Abstract class: subclass this class for different file formats)
@@ -51,10 +47,18 @@
's', "Source Type (HDFS, local FS)", FileStreamSource.class,
"LocalFileStreamSource");
+ public IntOption classIndexOption = new IntOption("classIndex", 'c',
+ "Class index of data. 0 for none or -1 for last attribute in file.", -1, -1, Integer.MAX_VALUE);
+
+ private FloatOption floatOption = new FloatOption("classWeight", 'w', "", 1.0);
+ public ListOption classWeightsOption = new ListOption("classWeights", 'w',
+ "Class weights in order of class index.", floatOption, new FloatOption[0], ':');
+
protected transient FileStreamSource fileSource;
//protected transient Reader fileReader;
protected transient InputStream inputStream;
protected Instances instances;
+ protected FloatOption[] classWeights;
protected boolean hitEndOfStream;
private boolean hasStarted;
@@ -99,6 +103,13 @@
readNextInstanceFromStream();
}
InstanceExample prevInstance = this.getLastInstanceRead();
+ if (classWeights != null && classWeights.length > 0) {
+ int i = (int) prevInstance.instance.classValue();
+ double w = 1.0;
+ if (i>=0 && i<classWeights.length)
+ w = classWeights[i].getValue();
+ prevInstance.setWeight(w);
+ }
readNextInstanceFromStream();
return prevInstance;
}
@@ -158,6 +169,8 @@
@Override
public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
this.fileSource = sourceTypeOption.getValue();
+ this.classWeights = java.util.Arrays.copyOf(classWeightsOption.getList(), classWeightsOption.getList().length, FloatOption[].class); // copy into a FloatOption[]: getList() is declared as Option[], and downcasting an array whose runtime type is Option[] throws ClassCastException
this.hasStarted = false;
}
+
}