blob: 5b822949ee81c50f5eef2a8c92868a703010e01f [file] [log] [blame]
package org.apache.samoa.instances;
/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2015 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import java.io.Serializable;
public class MultiLabelPrediction implements Prediction, Serializable {
protected DoubleVector[] prediction;
public MultiLabelPrediction() {
this(0);
}
public MultiLabelPrediction(int numOutputAttributes) {
prediction = new DoubleVector[numOutputAttributes];
for (int i = 0; i < numOutputAttributes; i++)
prediction[i] = new DoubleVector();
}
@Override
public int numOutputAttributes() {
return prediction.length;
}
@Override
public int numClasses(int outputAttributeIndex) {
int ret = 0;
if (prediction.length > outputAttributeIndex) {
ret = prediction[outputAttributeIndex].numValues();
}
return ret;
}
@Override
public double[] getVotes(int outputAttributeIndex) {
double ret[] = null;
if (prediction.length > outputAttributeIndex) {
ret = prediction[outputAttributeIndex].getArrayCopy();
}
return ret;
}
@Override
public double[] getVotes() {
return getVotes(0);
}
@Override
public void setVotes(double[] votes) {
setVotes(0, votes);
}
@Override
public double getVote(int outputAttributeIndex, int classIndex) {
double ret = 0.0;
if (prediction.length > outputAttributeIndex) {
ret = prediction[outputAttributeIndex].getValue(classIndex);
}
return ret;
}
@Override
public void setVotes(int outputAttributeIndex, double[] votes) {
for (int i = 0; i < votes.length; i++)
prediction[outputAttributeIndex].setValue(i, votes[i]);
}
@Override
public void setVote(int outputAttributeIndex, int classIndex, double vote) {
prediction[outputAttributeIndex].setValue(classIndex, vote);
}
@Override
public String toString() {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < prediction.length; i++) {
sb.append("Out " + i + ": ");
for (int c = 0; c < prediction[i].numValues(); c++) {
sb.append(((int) (prediction[i].getValue(c) * 1000) / 1000.0) + " ");
}
}
return sb.toString();
}
@Override
public boolean hasVotesForAttribute(int outputAttributeIndex) {
if (prediction.length < (outputAttributeIndex + 1))
return false;
return (prediction[outputAttributeIndex].numValues() == 0) ? false : true;
}
@Override
public int size() {
return prediction.length;
}
protected class DoubleVector implements Serializable {
private static final long serialVersionUID = 1L;
protected double[] array;
public DoubleVector() {
this.array = new double[0];
}
public DoubleVector(double[] toCopy) {
this.array = new double[toCopy.length];
System.arraycopy(toCopy, 0, this.array, 0, toCopy.length);
}
public int numValues() {
return this.array.length;
}
public void setValue(int i, double v) {
if (i >= this.array.length) {
setArrayLength(i + 1);
}
this.array[i] = v;
}
public void addToValue(int i, double v) {
if (i >= this.array.length) {
setArrayLength(i + 1);
}
this.array[i] += v;
}
// returns 0.0 for values outside of range
public double getValue(int i) {
return ((i >= 0) && (i < this.array.length)) ? this.array[i] : 0.0;
}
public int maxIndex() {
int max = -1;
for (int i = 0; i < this.array.length; i++) {
if ((max < 0) || (this.array[i] > this.array[max])) {
max = i;
}
}
return max;
}
public double[] getArrayCopy() {
double[] aCopy = new double[this.array.length];
System.arraycopy(this.array, 0, aCopy, 0, this.array.length);
return aCopy;
}
protected void setArrayLength(int l) {
double[] newArray = new double[l];
int numToCopy = this.array.length;
if (numToCopy > l) {
numToCopy = l;
}
System.arraycopy(this.array, 0, newArray, 0, numToCopy);
this.array = newArray;
}
}
}