| |
| package com.yahoo.labs.samoa.moa.classifiers; |
| |
| /* |
| * #%L |
| * SAMOA |
| * %% |
| * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand |
| * %% |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * #L% |
| */ |
| |
| import java.util.Arrays; |
| import java.util.LinkedList; |
| import java.util.List; |
| import java.util.Random; |
| |
| import com.github.javacliparser.IntOption; |
| import com.yahoo.labs.samoa.instances.Instance; |
| import com.yahoo.labs.samoa.instances.InstancesHeader; |
| import com.yahoo.labs.samoa.moa.MOAObject; |
| import com.yahoo.labs.samoa.moa.core.Example; |
| import com.yahoo.labs.samoa.moa.core.Measurement; |
| import com.yahoo.labs.samoa.moa.core.ObjectRepository; |
| import com.yahoo.labs.samoa.moa.core.StringUtils; |
| import com.yahoo.labs.samoa.moa.core.Utils; |
| import com.yahoo.labs.samoa.moa.learners.Learner; |
| import com.yahoo.labs.samoa.moa.options.AbstractOptionHandler; |
| import com.yahoo.labs.samoa.moa.tasks.TaskMonitor; |
| |
| public abstract class AbstractClassifier extends AbstractOptionHandler implements Classifier { |
| |
| @Override |
| public String getPurposeString() { |
| return "MOA Classifier: " + getClass().getCanonicalName(); |
| } |
| |
| /** Header of the instances of the data stream */ |
| protected InstancesHeader modelContext; |
| |
| /** Sum of the weights of the instances trained by this model */ |
| protected double trainingWeightSeenByModel = 0.0; |
| |
| /** Random seed used in randomizable learners */ |
| protected int randomSeed = 1; |
| |
| /** Option for randomizable learners to change the random seed */ |
| protected IntOption randomSeedOption; |
| |
| /** Random Generator used in randomizable learners */ |
| public Random classifierRandom; |
| |
| /** |
| * Creates an classifier and setups the random seed option |
| * if the classifier is randomizable. |
| */ |
| public AbstractClassifier() { |
| if (isRandomizable()) { |
| this.randomSeedOption = new IntOption("randomSeed", 'r', |
| "Seed for random behaviour of the classifier.", 1); |
| } |
| } |
| |
| @Override |
| public void prepareForUseImpl(TaskMonitor monitor, |
| ObjectRepository repository) { |
| if (this.randomSeedOption != null) { |
| this.randomSeed = this.randomSeedOption.getValue(); |
| } |
| if (!trainingHasStarted()) { |
| resetLearning(); |
| } |
| } |
| |
| |
| @Override |
| public double[] getVotesForInstance(Example<Instance> example){ |
| return getVotesForInstance(example.getData()); |
| } |
| |
| @Override |
| public abstract double[] getVotesForInstance(Instance inst); |
| |
| @Override |
| public void setModelContext(InstancesHeader ih) { |
| if ((ih != null) && (ih.classIndex() < 0)) { |
| throw new IllegalArgumentException( |
| "Context for a classifier must include a class to learn"); |
| } |
| if (trainingHasStarted() |
| && (this.modelContext != null) |
| && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) { |
| throw new IllegalArgumentException( |
| "New context is not compatible with existing model"); |
| } |
| this.modelContext = ih; |
| } |
| |
| @Override |
| public InstancesHeader getModelContext() { |
| return this.modelContext; |
| } |
| |
| @Override |
| public void setRandomSeed(int s) { |
| this.randomSeed = s; |
| if (this.randomSeedOption != null) { |
| // keep option consistent |
| this.randomSeedOption.setValue(s); |
| } |
| } |
| |
| @Override |
| public boolean trainingHasStarted() { |
| return this.trainingWeightSeenByModel > 0.0; |
| } |
| |
| @Override |
| public double trainingWeightSeenByModel() { |
| return this.trainingWeightSeenByModel; |
| } |
| |
| @Override |
| public void resetLearning() { |
| this.trainingWeightSeenByModel = 0.0; |
| if (isRandomizable()) { |
| this.classifierRandom = new Random(this.randomSeed); |
| } |
| resetLearningImpl(); |
| } |
| |
| @Override |
| public void trainOnInstance(Instance inst) { |
| if (inst.weight() > 0.0) { |
| this.trainingWeightSeenByModel += inst.weight(); |
| trainOnInstanceImpl(inst); |
| } |
| } |
| |
| @Override |
| public Measurement[] getModelMeasurements() { |
| List<Measurement> measurementList = new LinkedList<>(); |
| measurementList.add(new Measurement("model training instances", |
| trainingWeightSeenByModel())); |
| measurementList.add(new Measurement("model serialized size (bytes)", |
| measureByteSize())); |
| Measurement[] modelMeasurements = getModelMeasurementsImpl(); |
| if (modelMeasurements != null) { |
| measurementList.addAll(Arrays.asList(modelMeasurements)); |
| } |
| // add average of sub-model measurements |
| Learner[] subModels = getSublearners(); |
| if ((subModels != null) && (subModels.length > 0)) { |
| List<Measurement[]> subMeasurements = new LinkedList<>(); |
| for (Learner subModel : subModels) { |
| if (subModel != null) { |
| subMeasurements.add(subModel.getModelMeasurements()); |
| } |
| } |
| Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][])); |
| measurementList.addAll(Arrays.asList(avgMeasurements)); |
| } |
| return measurementList.toArray(new Measurement[measurementList.size()]); |
| } |
| |
| @Override |
| public void getDescription(StringBuilder out, int indent) { |
| StringUtils.appendIndented(out, indent, "Model type: "); |
| out.append(this.getClass().getName()); |
| StringUtils.appendNewline(out); |
| Measurement.getMeasurementsDescription(getModelMeasurements(), out, |
| indent); |
| StringUtils.appendNewlineIndented(out, indent, "Model description:"); |
| StringUtils.appendNewline(out); |
| if (trainingHasStarted()) { |
| getModelDescription(out, indent); |
| } else { |
| StringUtils.appendIndented(out, indent, |
| "Model has not been trained."); |
| } |
| } |
| |
| @Override |
| public Learner[] getSublearners() { |
| return null; |
| } |
| |
| |
| @Override |
| public Classifier[] getSubClassifiers() { |
| return null; |
| } |
| |
| |
| @Override |
| public Classifier copy() { |
| return (Classifier) super.copy(); |
| } |
| |
| |
| @Override |
| public MOAObject getModel(){ |
| return this; |
| } |
| |
| @Override |
| public void trainOnInstance(Example<Instance> example){ |
| trainOnInstance(example.getData()); |
| } |
| |
| @Override |
| public boolean correctlyClassifies(Instance inst) { |
| return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue(); |
| } |
| |
| /** |
| * Gets the name of the attribute of the class from the header. |
| * |
| * @return the string with name of the attribute of the class |
| */ |
| public String getClassNameString() { |
| return InstancesHeader.getClassNameString(this.modelContext); |
| } |
| |
| /** |
| * Gets the name of a label of the class from the header. |
| * |
| * @param classLabelIndex the label index |
| * @return the name of the label of the class |
| */ |
| public String getClassLabelString(int classLabelIndex) { |
| return InstancesHeader.getClassLabelString(this.modelContext, |
| classLabelIndex); |
| } |
| |
| /** |
| * Gets the name of an attribute from the header. |
| * |
| * @param attIndex the attribute index |
| * @return the name of the attribute |
| */ |
| public String getAttributeNameString(int attIndex) { |
| return InstancesHeader.getAttributeNameString(this.modelContext, attIndex); |
| } |
| |
| /** |
| * Gets the name of a value of an attribute from the header. |
| * |
| * @param attIndex the attribute index |
| * @param valIndex the value of the attribute |
| * @return the name of the value of the attribute |
| */ |
| public String getNominalValueString(int attIndex, int valIndex) { |
| return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex); |
| } |
| |
| |
| /** |
| * Returns if two contexts or headers of instances are compatible.<br><br> |
| * |
| * Two contexts are compatible if they follow the following rules:<br> |
| * Rule 1: num classes can increase but never decrease<br> |
| * Rule 2: num attributes can increase but never decrease<br> |
| * Rule 3: num nominal attribute values can increase but never decrease<br> |
| * Rule 4: attribute types must stay in the same order (although class |
| * can move; is always skipped over)<br><br> |
| * |
| * Attribute names are free to change, but should always still represent |
| * the original attributes. |
| * |
| * @param originalContext the first context to compare |
| * @param newContext the second context to compare |
| * @return true if the two contexts are compatible. |
| */ |
| public static boolean contextIsCompatible(InstancesHeader originalContext, |
| InstancesHeader newContext) { |
| |
| if (newContext.numClasses() < originalContext.numClasses()) { |
| return false; // rule 1 |
| } |
| if (newContext.numAttributes() < originalContext.numAttributes()) { |
| return false; // rule 2 |
| } |
| int oPos = 0; |
| int nPos = 0; |
| while (oPos < originalContext.numAttributes()) { |
| if (oPos == originalContext.classIndex()) { |
| oPos++; |
| if (!(oPos < originalContext.numAttributes())) { |
| break; |
| } |
| } |
| if (nPos == newContext.classIndex()) { |
| nPos++; |
| } |
| if (originalContext.attribute(oPos).isNominal()) { |
| if (!newContext.attribute(nPos).isNominal()) { |
| return false; // rule 4 |
| } |
| if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) { |
| return false; // rule 3 |
| } |
| } else { |
| assert (originalContext.attribute(oPos).isNumeric()); |
| if (!newContext.attribute(nPos).isNumeric()) { |
| return false; // rule 4 |
| } |
| } |
| oPos++; |
| nPos++; |
| } |
| return true; // all checks clear |
| } |
| |
| |
| |
| /** |
| * Resets this classifier. It must be similar to |
| * starting a new classifier from scratch. <br><br> |
| * |
| * The reason for ...Impl methods: ease programmer burden by not requiring |
| * them to remember calls to super in overridden methods. |
| * Note that this will produce compiler errors if not overridden. |
| */ |
| public abstract void resetLearningImpl(); |
| |
| /** |
| * Trains this classifier incrementally using the given instance.<br><br> |
| * |
| * The reason for ...Impl methods: ease programmer burden by not requiring |
| * them to remember calls to super in overridden methods. |
| * Note that this will produce compiler errors if not overridden. |
| * |
| * @param inst the instance to be used for training |
| */ |
| public abstract void trainOnInstanceImpl(Instance inst); |
| |
| /** |
| * Gets the current measurements of this classifier.<br><br> |
| * |
| * The reason for ...Impl methods: ease programmer burden by not requiring |
| * them to remember calls to super in overridden methods. |
| * Note that this will produce compiler errors if not overridden. |
| * |
| * @return an array of measurements to be used in evaluation tasks |
| */ |
| protected abstract Measurement[] getModelMeasurementsImpl(); |
| |
| /** |
| * Returns a string representation of the model. |
| * |
| * @param out the stringbuilder to add the description |
| * @param indent the number of characters to indent |
| */ |
| public abstract void getModelDescription(StringBuilder out, int indent); |
| |
| /** |
| * Gets the index of the attribute in the instance, |
| * given the index of the attribute in the learner. |
| * |
| * @param index the index of the attribute in the learner |
| * @return the index in the instance |
| */ |
| protected static int modelAttIndexToInstanceAttIndex(int index) { |
| return index; //inst.classIndex() > index ? index : index + 1; |
| } |
| } |