// blob: 46741d77f6c42516e3aa8342b23697c0062a47d8
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samoa.instances;
import java.io.Serializable;
/**
 * A {@link Prediction} that stores one vector of votes per output attribute.
 *
 * <p>Read accessors ({@link #numClasses(int)}, {@link #getVotes(int)},
 * {@link #getVote(int, int)}, {@link #hasVotesForAttribute(int)}) treat any output
 * index outside {@code [0, numOutputAttributes())} as "no votes" and return a
 * neutral value instead of throwing. Write accessors do not create missing
 * outputs: writing to an output index that does not exist throws
 * {@link ArrayIndexOutOfBoundsException}.
 */
public class MultiLabelPrediction implements Prediction, Serializable {
  // NOTE(review): no serialVersionUID is declared for this class. Introducing one
  // whose value differs from the currently computed default would make previously
  // serialized instances unreadable - confirm before adding it.

  /** One vote vector per output attribute; this class never resizes the array after construction. */
  protected DoubleVector[] prediction;

  /** Creates a prediction with zero output attributes. */
  public MultiLabelPrediction() {
    this(0);
  }

  /**
   * Creates a prediction with the given number of output attributes, each starting
   * with an empty vote vector.
   *
   * @param numOutputAttributes number of output attributes to allocate
   */
  public MultiLabelPrediction(int numOutputAttributes) {
    prediction = new DoubleVector[numOutputAttributes];
    for (int i = 0; i < numOutputAttributes; i++) {
      prediction[i] = new DoubleVector();
    }
  }

  /** @return the number of output attributes held by this prediction */
  @Override
  public int numOutputAttributes() {
    return prediction.length;
  }

  /**
   * @param outputAttributeIndex index of the output attribute
   * @return the number of vote slots stored for that output, or {@code 0} when the
   *         index is out of range
   */
  @Override
  public int numClasses(int outputAttributeIndex) {
    return hasOutput(outputAttributeIndex) ? prediction[outputAttributeIndex].numValues() : 0;
  }

  /**
   * @param outputAttributeIndex index of the output attribute
   * @return a copy of the votes stored for that output, or {@code null} when the
   *         index is out of range
   */
  @Override
  public double[] getVotes(int outputAttributeIndex) {
    return hasOutput(outputAttributeIndex) ? prediction[outputAttributeIndex].getArrayCopy() : null;
  }

  /**
   * @return a copy of the votes of output attribute 0, or {@code null} when this
   *         prediction has no output attributes
   */
  @Override
  public double[] getVotes() {
    return getVotes(0);
  }

  /**
   * Stores the given votes on output attribute 0; see {@link #setVotes(int, double[])}.
   *
   * @param votes votes to store
   */
  @Override
  public void setVotes(double[] votes) {
    setVotes(0, votes);
  }

  /**
   * @param outputAttributeIndex index of the output attribute
   * @param classIndex index of the vote within that output
   * @return the stored vote, or {@code 0.0} when either index is out of range
   */
  @Override
  public double getVote(int outputAttributeIndex, int classIndex) {
    return hasOutput(outputAttributeIndex) ? prediction[outputAttributeIndex].getValue(classIndex) : 0.0;
  }

  /**
   * Copies {@code votes[i]} into slot {@code i} of the given output attribute,
   * growing the vote vector when it is shorter than {@code votes}. Slots at index
   * {@code votes.length} and above that were already present are left untouched.
   *
   * @param outputAttributeIndex index of the output attribute to write
   * @param votes votes to store; an empty array is a no-op
   * @throws ArrayIndexOutOfBoundsException if {@code votes} is non-empty and the
   *         output index does not exist
   * @throws NullPointerException if {@code votes} is {@code null}
   */
  @Override
  public void setVotes(int outputAttributeIndex, double[] votes) {
    // Write from the highest index down: the first write grows the vector to its
    // final size, so the backing array is reallocated at most once instead of once
    // per element. The resulting contents are the same as an ascending loop.
    for (int i = votes.length - 1; i >= 0; i--) {
      prediction[outputAttributeIndex].setValue(i, votes[i]);
    }
  }

  /**
   * Stores a single vote, growing the vote vector when {@code classIndex} is beyond
   * its current size.
   *
   * @param outputAttributeIndex index of the output attribute to write
   * @param classIndex index of the vote within that output
   * @param vote value to store
   * @throws ArrayIndexOutOfBoundsException if the output index does not exist or
   *         {@code classIndex} is negative
   */
  @Override
  public void setVote(int outputAttributeIndex, int classIndex, double vote) {
    prediction[outputAttributeIndex].setValue(classIndex, vote);
  }

  /**
   * Renders every output as {@code "Out <i>: "} followed by its votes, each
   * truncated (toward zero) to three decimals and followed by a single space.
   *
   * @return the textual form described above; empty when there are no outputs
   */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < prediction.length; i++) {
      sb.append("Out ").append(i).append(": ");
      DoubleVector votes = prediction[i];
      for (int c = 0; c < votes.numValues(); c++) {
        // Truncate through long, not int: an int cast saturates at
        // Integer.MAX_VALUE, which misprinted every vote above ~2.1 million.
        sb.append((long) (votes.getValue(c) * 1000) / 1000.0).append(' ');
      }
    }
    return sb.toString();
  }

  /**
   * @param outputAttributeIndex index of the output attribute
   * @return {@code true} when the index is in range and at least one vote slot is
   *         stored for that output
   */
  @Override
  public boolean hasVotesForAttribute(int outputAttributeIndex) {
    return hasOutput(outputAttributeIndex) && prediction[outputAttributeIndex].numValues() != 0;
  }

  /** @return the number of output attributes (same value as {@link #numOutputAttributes()}) */
  @Override
  public int size() {
    return prediction.length;
  }

  /** Returns whether {@code outputAttributeIndex} addresses an existing output attribute. */
  private boolean hasOutput(int outputAttributeIndex) {
    return outputAttributeIndex >= 0 && outputAttributeIndex < prediction.length;
  }

  /**
   * Growable array of doubles in which reads outside the current range yield
   * {@code 0.0} and writes beyond the current end enlarge the array.
   *
   * <p>This is a non-static inner class, so every instance carries an implicit
   * reference to the enclosing {@link MultiLabelPrediction}.
   */
  protected class DoubleVector implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Backing storage; its length is the number of values held. */
    protected double[] array;

    /** Creates an empty vector. */
    public DoubleVector() {
      this.array = new double[0];
    }

    /**
     * Creates a vector holding a copy of the given values.
     *
     * @param toCopy values to copy; later changes to it do not affect this vector
     */
    public DoubleVector(double[] toCopy) {
      this.array = toCopy.clone();
    }

    /** @return the number of values currently stored */
    public int numValues() {
      return this.array.length;
    }

    /**
     * Stores {@code v} at index {@code i}, enlarging the vector to {@code i + 1}
     * values (new slots are {@code 0.0}) when {@code i} is beyond the current end.
     *
     * @param i index to write
     * @param v value to store
     * @throws ArrayIndexOutOfBoundsException if {@code i} is negative
     */
    public void setValue(int i, double v) {
      if (i >= this.array.length) {
        setArrayLength(i + 1);
      }
      this.array[i] = v;
    }

    /**
     * Adds {@code v} to the value at index {@code i}, enlarging the vector first
     * when {@code i} is beyond the current end.
     *
     * @param i index to update
     * @param v amount to add
     * @throws ArrayIndexOutOfBoundsException if {@code i} is negative
     */
    public void addToValue(int i, double v) {
      if (i >= this.array.length) {
        setArrayLength(i + 1);
      }
      this.array[i] += v;
    }

    /**
     * @param i index to read
     * @return the value at {@code i}, or {@code 0.0} for indices outside the
     *         current range
     */
    public double getValue(int i) {
      return ((i >= 0) && (i < this.array.length)) ? this.array[i] : 0.0;
    }

    /**
     * @return the lowest index holding the largest value under a strict
     *         {@code >} comparison, or {@code -1} when the vector is empty
     */
    public int maxIndex() {
      int max = -1;
      for (int i = 0; i < this.array.length; i++) {
        if ((max < 0) || (this.array[i] > this.array[max])) {
          max = i;
        }
      }
      return max;
    }

    /** @return a fresh array containing the current values */
    public double[] getArrayCopy() {
      return this.array.clone();
    }

    /**
     * Resizes the backing array to exactly {@code l} values, keeping the leading
     * values that still fit and zero-filling any newly added slots.
     *
     * @param l new number of values
     */
    protected void setArrayLength(int l) {
      double[] newArray = new double[l];
      int numToCopy = Math.min(this.array.length, l);
      System.arraycopy(this.array, 0, newArray, 0, numToCopy);
      this.array = newArray;
    }
  }
}