blob: 4ca4c0f78094f53e2126d5187a7fd184153ad84f [file] [log] [blame]
package org.apache.samoa.moa.classifiers;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.InstancesHeader;
import org.apache.samoa.moa.MOAObject;
import org.apache.samoa.moa.core.Example;
import org.apache.samoa.moa.core.Measurement;
import org.apache.samoa.moa.core.ObjectRepository;
import org.apache.samoa.moa.core.StringUtils;
import org.apache.samoa.moa.core.Utils;
import org.apache.samoa.moa.learners.Learner;
import org.apache.samoa.moa.options.AbstractOptionHandler;
import org.apache.samoa.moa.tasks.TaskMonitor;
import com.github.javacliparser.IntOption;
/**
 * Skeleton implementation of {@link Classifier}.
 *
 * <p>
 * This class keeps track of the model context (stream header), the total training weight seen so far and the random
 * seed / generator of randomizable learners. The actual learning logic is delegated to the abstract {@code ...Impl}
 * methods so that subclasses do not have to remember to call {@code super} in overridden methods.
 */
public abstract class AbstractClassifier extends AbstractOptionHandler implements Classifier {

  /**
   * Gets a one-line description of the purpose of this object.
   *
   * @return the fixed prefix {@code "MOA Classifier: "} followed by the canonical name of the concrete class
   */
  @Override
  public String getPurposeString() {
    return "MOA Classifier: " + getClass().getCanonicalName();
  }

  /** Header of the instances of the data stream */
  protected InstancesHeader modelContext;

  /** Sum of the weights of the instances trained by this model */
  protected double trainingWeightSeenByModel = 0.0;

  /** Random seed used in randomizable learners */
  protected int randomSeed = 1;

  /** Option for randomizable learners to change the random seed; {@code null} when the learner is not randomizable */
  protected IntOption randomSeedOption;

  /** Random generator used in randomizable learners; (re)created by {@link #resetLearning()} */
  public Random classifierRandom;

  /**
   * Creates a classifier and sets up the random seed option if the classifier is randomizable.
   *
   * <p>
   * Note: {@code isRandomizable()} is invoked from this constructor, i.e. before any subclass constructor body or
   * field initializer has run, so implementations must not rely on subclass state.
   */
  public AbstractClassifier() {
    if (isRandomizable()) {
      this.randomSeedOption = new IntOption("randomSeed", 'r',
          "Seed for random behaviour of the classifier.", 1);
    }
  }

  /**
   * Prepares this classifier for use: reads the random seed from its option (when the option exists) and resets the
   * learner unless training has already started.
   *
   * @param monitor
   *          the task monitor (not used by this implementation)
   * @param repository
   *          the object repository (not used by this implementation)
   */
  @Override
  public void prepareForUseImpl(TaskMonitor monitor,
      ObjectRepository repository) {
    if (this.randomSeedOption != null) {
      this.randomSeed = this.randomSeedOption.getValue();
    }
    // Do not wipe a model that has already seen training data.
    if (!trainingHasStarted()) {
      resetLearning();
    }
  }

  /**
   * Gets the votes for the instance wrapped in the given example.
   *
   * @param example
   *          the example whose data is to be classified
   * @return the result of {@link #getVotesForInstance(Instance)} on the wrapped instance
   */
  @Override
  public double[] getVotesForInstance(Example<Instance> example) {
    return getVotesForInstance(example.getData());
  }

  /**
   * Gets the votes of this classifier for the given instance.
   *
   * @param inst
   *          the instance to be classified
   * @return an array of votes; {@link #correctlyClassifies(Instance)} treats the index of the largest entry as the
   *         predicted class
   */
  @Override
  public abstract double[] getVotesForInstance(Instance inst);

  /**
   * Sets the header of the stream this classifier learns from.
   *
   * @param ih
   *          the new header; may be {@code null} only while no incompatible model exists
   * @throws IllegalArgumentException
   *           if {@code ih} is non-null but has no class attribute, or if training has started with a non-null
   *           context and {@code ih} is {@code null} or not compatible with that context
   */
  @Override
  public void setModelContext(InstancesHeader ih) {
    if ((ih != null) && (ih.classIndex() < 0)) {
      throw new IllegalArgumentException(
          "Context for a classifier must include a class to learn");
    }
    if (trainingHasStarted()
        && (this.modelContext != null)
        && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) {
      throw new IllegalArgumentException(
          "New context is not compatible with existing model");
    }
    this.modelContext = ih;
  }

  /**
   * Gets the header of the stream this classifier learns from.
   *
   * @return the current model context, possibly {@code null}
   */
  @Override
  public InstancesHeader getModelContext() {
    return this.modelContext;
  }

  /**
   * Sets the random seed and mirrors it into the random seed option when that option exists.
   *
   * @param s
   *          the new random seed
   */
  @Override
  public void setRandomSeed(int s) {
    this.randomSeed = s;
    if (this.randomSeedOption != null) {
      // keep option consistent
      this.randomSeedOption.setValue(s);
    }
  }

  /**
   * Tells whether this classifier has been trained with any weight.
   *
   * @return {@code true} if the accumulated training weight is strictly positive
   */
  @Override
  public boolean trainingHasStarted() {
    return this.trainingWeightSeenByModel > 0.0;
  }

  /**
   * Gets the sum of the weights of the instances used to train this model.
   *
   * @return the accumulated training weight
   */
  @Override
  public double trainingWeightSeenByModel() {
    return this.trainingWeightSeenByModel;
  }

  /**
   * Resets this classifier: clears the accumulated training weight, re-creates the random generator from the current
   * seed (randomizable learners only) and then delegates to {@link #resetLearningImpl()}.
   */
  @Override
  public void resetLearning() {
    this.trainingWeightSeenByModel = 0.0;
    if (isRandomizable()) {
      this.classifierRandom = new Random(this.randomSeed);
    }
    resetLearningImpl();
  }

  /**
   * Trains this classifier on the given instance. Instances whose weight is not strictly positive are ignored.
   *
   * @param inst
   *          the instance to be used for training
   */
  @Override
  public void trainOnInstance(Instance inst) {
    if (inst.weight() > 0.0) {
      this.trainingWeightSeenByModel += inst.weight();
      trainOnInstanceImpl(inst);
    }
  }

  /**
   * Gets the measurements of this model: the training weight seen, the serialized size, the learner-specific
   * measurements and, when there are sub-learners, the average of their measurements.
   *
   * @return an array with all the measurements, in the order described above
   */
  @Override
  public Measurement[] getModelMeasurements() {
    List<Measurement> measurementList = new LinkedList<>();
    measurementList.add(new Measurement("model training instances",
        trainingWeightSeenByModel()));
    measurementList.add(new Measurement("model serialized size (bytes)",
        measureByteSize()));
    Measurement[] modelMeasurements = getModelMeasurementsImpl();
    if (modelMeasurements != null) {
      measurementList.addAll(Arrays.asList(modelMeasurements));
    }
    // add average of sub-model measurements
    Learner[] subModels = getSublearners();
    if ((subModels != null) && (subModels.length > 0)) {
      List<Measurement[]> subMeasurements = new LinkedList<>();
      for (Learner subModel : subModels) {
        if (subModel != null) {
          subMeasurements.add(subModel.getModelMeasurements());
        }
      }
      Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements
          .toArray(new Measurement[subMeasurements.size()][]));
      measurementList.addAll(Arrays.asList(avgMeasurements));
    }
    return measurementList.toArray(new Measurement[measurementList.size()]);
  }

  /**
   * Appends a textual description of this model: its type, its measurements and, once training has started, the
   * learner-specific model description.
   *
   * @param out
   *          the string builder to append the description to
   * @param indent
   *          the number of characters to indent
   */
  @Override
  public void getDescription(StringBuilder out, int indent) {
    StringUtils.appendIndented(out, indent, "Model type: ");
    out.append(this.getClass().getName());
    StringUtils.appendNewline(out);
    Measurement.getMeasurementsDescription(getModelMeasurements(), out,
        indent);
    StringUtils.appendNewlineIndented(out, indent, "Model description:");
    StringUtils.appendNewline(out);
    if (trainingHasStarted()) {
      getModelDescription(out, indent);
    } else {
      StringUtils.appendIndented(out, indent,
          "Model has not been trained.");
    }
  }

  /**
   * Gets the sub-learners of this classifier.
   *
   * @return {@code null} in this base implementation, which {@link #getModelMeasurements()} treats as "no
   *         sub-learners"
   */
  @Override
  public Learner[] getSublearners() {
    return null;
  }

  /**
   * Gets the sub-classifiers of this classifier.
   *
   * @return {@code null} in this base implementation
   */
  @Override
  public Classifier[] getSubClassifiers() {
    return null;
  }

  /**
   * Produces a copy of this classifier.
   *
   * @return the result of the inherited {@code copy()}, cast to {@link Classifier}
   */
  @Override
  public Classifier copy() {
    return (Classifier) super.copy();
  }

  /**
   * Gets the model held by this learner.
   *
   * @return this classifier itself
   */
  @Override
  public MOAObject getModel() {
    return this;
  }

  /**
   * Trains this classifier on the instance wrapped in the given example.
   *
   * @param example
   *          the example whose data is to be used for training
   */
  @Override
  public void trainOnInstance(Example<Instance> example) {
    trainOnInstance(example.getData());
  }

  /**
   * Tells whether this classifier predicts the class of the given instance correctly.
   *
   * @param inst
   *          the instance to be classified
   * @return {@code true} if the index of the largest vote equals the class value of the instance truncated to an
   *         {@code int}
   */
  @Override
  public boolean correctlyClassifies(Instance inst) {
    return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue();
  }

  /**
   * Gets the name of the attribute of the class from the header.
   *
   * @return the string with name of the attribute of the class
   */
  public String getClassNameString() {
    return InstancesHeader.getClassNameString(this.modelContext);
  }

  /**
   * Gets the name of a label of the class from the header.
   *
   * @param classLabelIndex
   *          the label index
   * @return the name of the label of the class
   */
  public String getClassLabelString(int classLabelIndex) {
    return InstancesHeader.getClassLabelString(this.modelContext,
        classLabelIndex);
  }

  /**
   * Gets the name of an attribute from the header.
   *
   * @param attIndex
   *          the attribute index
   * @return the name of the attribute
   */
  public String getAttributeNameString(int attIndex) {
    return InstancesHeader.getAttributeNameString(this.modelContext, attIndex);
  }

  /**
   * Gets the name of a value of an attribute from the header.
   *
   * @param attIndex
   *          the attribute index
   * @param valIndex
   *          the value of the attribute
   * @return the name of the value of the attribute
   */
  public String getNominalValueString(int attIndex, int valIndex) {
    return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex);
  }

  /**
   * Returns if two contexts or headers of instances are compatible.<br>
   * <br>
   *
   * Two contexts are compatible if they follow the following rules:<br>
   * Rule 1: num classes can increase but never decrease<br>
   * Rule 2: num attributes can increase but never decrease<br>
   * Rule 3: num nominal attribute values can increase but never decrease<br>
   * Rule 4: attribute types must stay in the same order (although class can move; is always skipped over)<br>
   * <br>
   *
   * Attribute names are free to change, but should always still represent the original attributes.<br>
   * <br>
   *
   * If the new context runs out of non-class attributes before every non-class attribute of the original context has
   * been matched, the contexts are reported as incompatible (rule 2).
   *
   * @param originalContext
   *          the first context to compare
   * @param newContext
   *          the second context to compare
   * @return true if the two contexts are compatible.
   */
  public static boolean contextIsCompatible(InstancesHeader originalContext,
      InstancesHeader newContext) {
    if (newContext.numClasses() < originalContext.numClasses()) {
      return false; // rule 1
    }
    if (newContext.numAttributes() < originalContext.numAttributes()) {
      return false; // rule 2
    }
    int oPos = 0;
    int nPos = 0;
    while (oPos < originalContext.numAttributes()) {
      // Skip the class attribute of the original context; stop if it was the last attribute.
      if (oPos == originalContext.classIndex()) {
        oPos++;
        if (!(oPos < originalContext.numAttributes())) {
          break;
        }
      }
      // Skip the class attribute of the new context.
      if (nPos == newContext.classIndex()) {
        nPos++;
      }
      // The count check above includes the class attribute, so the new context can still have fewer
      // non-class attributes than the original (e.g. the original has no class attribute). Report
      // that as incompatible instead of indexing past the end of the new header.
      if (nPos >= newContext.numAttributes()) {
        return false; // rule 2
      }
      if (originalContext.attribute(oPos).isNominal()) {
        if (!newContext.attribute(nPos).isNominal()) {
          return false; // rule 4
        }
        if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) {
          return false; // rule 3
        }
      } else {
        assert (originalContext.attribute(oPos).isNumeric());
        if (!newContext.attribute(nPos).isNumeric()) {
          return false; // rule 4
        }
      }
      oPos++;
      nPos++;
    }
    return true; // all checks clear
  }

  /**
   * Resets this classifier. It must be similar to starting a new classifier from scratch. <br>
   * <br>
   *
   * The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in
   * overridden methods. Note that this will produce compiler errors if not overridden.
   */
  public abstract void resetLearningImpl();

  /**
   * Trains this classifier incrementally using the given instance.<br>
   * <br>
   *
   * The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in
   * overridden methods. Note that this will produce compiler errors if not overridden.
   *
   * @param inst
   *          the instance to be used for training
   */
  public abstract void trainOnInstanceImpl(Instance inst);

  /**
   * Gets the current measurements of this classifier.<br>
   * <br>
   *
   * The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in
   * overridden methods. Note that this will produce compiler errors if not overridden.
   *
   * @return an array of measurements to be used in evaluation tasks; may be {@code null}, in which case
   *         {@link #getModelMeasurements()} adds no learner-specific measurements
   */
  protected abstract Measurement[] getModelMeasurementsImpl();

  /**
   * Returns a string representation of the model.
   *
   * @param out
   *          the stringbuilder to add the description
   * @param indent
   *          the number of characters to indent
   */
  public abstract void getModelDescription(StringBuilder out, int indent);

  /**
   * Gets the index of the attribute in the instance, given the index of the attribute in the learner.
   *
   * @param index
   *          the index of the attribute in the learner
   * @return the index in the instance
   */
  protected static int modelAttIndexToInstanceAttIndex(int index) {
    // Currently the identity mapping. An earlier variant that shifted indices at or after the class
    // attribute (inst.classIndex() > index ? index : index + 1) is disabled.
    return index;
  }
}