// blob: 46741d77f6c42516e3aa8342b23697c0062a47d8 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.samoa.instances;

import java.io.Serializable;
public class MultiLabelPrediction implements Prediction, Serializable {
protected DoubleVector[] prediction;
public MultiLabelPrediction() {
public MultiLabelPrediction(int numOutputAttributes) {
prediction = new DoubleVector[numOutputAttributes];
for (int i = 0; i < numOutputAttributes; i++)
prediction[i] = new DoubleVector();
public int numOutputAttributes() {
return prediction.length;
public int numClasses(int outputAttributeIndex) {
int ret = 0;
if (prediction.length > outputAttributeIndex) {
ret = prediction[outputAttributeIndex].numValues();
return ret;
public double[] getVotes(int outputAttributeIndex) {
double ret[] = null;
if (prediction.length > outputAttributeIndex) {
ret = prediction[outputAttributeIndex].getArrayCopy();
return ret;
public double[] getVotes() {
return getVotes(0);
public void setVotes(double[] votes) {
setVotes(0, votes);
public double getVote(int outputAttributeIndex, int classIndex) {
double ret = 0.0;
if (prediction.length > outputAttributeIndex) {
ret = prediction[outputAttributeIndex].getValue(classIndex);
return ret;
public void setVotes(int outputAttributeIndex, double[] votes) {
for (int i = 0; i < votes.length; i++)
prediction[outputAttributeIndex].setValue(i, votes[i]);
public void setVote(int outputAttributeIndex, int classIndex, double vote) {
prediction[outputAttributeIndex].setValue(classIndex, vote);
public String toString() {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < prediction.length; i++) {
sb.append("Out " + i + ": ");
for (int c = 0; c < prediction[i].numValues(); c++) {
sb.append(((int) (prediction[i].getValue(c) * 1000) / 1000.0) + " ");
return sb.toString();
public boolean hasVotesForAttribute(int outputAttributeIndex) {
if (prediction.length < (outputAttributeIndex + 1))
return false;
return (prediction[outputAttributeIndex].numValues() == 0) ? false : true;
public int size() {
return prediction.length;
protected class DoubleVector implements Serializable {
private static final long serialVersionUID = 1L;
protected double[] array;
public DoubleVector() {
this.array = new double[0];
public DoubleVector(double[] toCopy) {
this.array = new double[toCopy.length];
System.arraycopy(toCopy, 0, this.array, 0, toCopy.length);
public int numValues() {
return this.array.length;
public void setValue(int i, double v) {
if (i >= this.array.length) {
setArrayLength(i + 1);
this.array[i] = v;
public void addToValue(int i, double v) {
if (i >= this.array.length) {
setArrayLength(i + 1);
this.array[i] += v;
// returns 0.0 for values outside of range
public double getValue(int i) {
return ((i >= 0) && (i < this.array.length)) ? this.array[i] : 0.0;
public int maxIndex() {
int max = -1;
for (int i = 0; i < this.array.length; i++) {
if ((max < 0) || (this.array[i] > this.array[max])) {
max = i;
return max;
public double[] getArrayCopy() {
double[] aCopy = new double[this.array.length];
System.arraycopy(this.array, 0, aCopy, 0, this.array.length);
return aCopy;
protected void setArrayLength(int l) {
double[] newArray = new double[l];
int numToCopy = this.array.length;
if (numToCopy > l) {
numToCopy = l;
System.arraycopy(this.array, 0, newArray, 0, numToCopy);
this.array = newArray;