// blob: 6c358ddd31148f8318220d8a29539a818458d23a [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.s4.model;
import java.util.Random;
import org.ejml.data.DenseMatrix64F;
import junit.framework.Assert;
import junit.framework.TestCase;
/**
 * Checks that {@link GaussianModel} recovers the mean of synthetic Gaussian
 * data, feeding the same samples through each of the three {@code update}
 * overloads exercised here (EJML matrix, {@code double[]} and {@code float[]}).
 *
 * <p>
 * The data set is drawn from independent Gaussians with the per-component
 * {@link #mean} and {@link #std} below, using a fixed random seed so every run
 * sees identical samples.
 */
public class TestGaussianModel extends TestCase {

    /** Number of sample vectors generated for each test. */
    private static final int NUM_VECTORS = 100000;

    /**
     * Half-width of the acceptance band around the true mean, expressed in
     * standard errors of the sample mean ({@code std / sqrt(NUM_VECTORS)}).
     * Six standard errors is far tighter than one full standard deviation of
     * the data, yet wide enough that a correct estimator practically never
     * falls outside it.
     */
    private static final double TOLERANCE_IN_STANDARD_ERRORS = 6.0;

    /** True mean of each component of the synthetic data. */
    private final double[] mean = { 153, 10.0, 5.0, 0.1 };

    /** True standard deviation of each component of the synthetic data. */
    private final double[] std = { 30, 2.0, 1.0, 5.5 };

    /** Dimension of every sample vector. */
    private final int numElements = mean.length;

    /* The same samples stored in the three representations under test. */
    private final DenseMatrix64F[] vectors = new DenseMatrix64F[NUM_VECTORS];
    private final double[][] doubleArrays = new double[NUM_VECTORS][numElements];
    private final float[][] floatArrays = new float[NUM_VECTORS][numElements];

    /** Fixed seed keeps the generated data set reproducible. */
    private final Random random = new Random(0);

    /**
     * Generates the data set: every sample component is
     * {@code mean[j] + std[j] * gaussian} and is written to all three
     * representations so the tests operate on identical values (the float
     * copy is narrowed from the double value).
     */
    protected void setUp() {
        for (int i = 0; i < NUM_VECTORS; i++) {
            vectors[i] = new DenseMatrix64F(numElements, 1);
            for (int j = 0; j < numElements; j++) {
                double v = mean[j] + std[j] * random.nextGaussian();
                vectors[i].set(j, v);
                doubleArrays[i][j] = v;
                floatArrays[i][j] = (float) v;
            }
        }
    }

    /** Trains on the EJML matrix representation and verifies the mean. */
    public void testTrainerUsingEJML() {
        GaussianModel gm = newModel();
        for (int i = 0; i < NUM_VECTORS; i++) {
            gm.update(vectors[i]);
        }
        estimateAndCheckMean(gm);
    }

    /** Trains on the {@code double[]} representation and verifies the mean. */
    public void testTrainerUsingDoubleArray() {
        GaussianModel gm = newModel();
        for (int i = 0; i < NUM_VECTORS; i++) {
            gm.update(doubleArrays[i]);
        }
        estimateAndCheckMean(gm);
    }

    /** Trains on the {@code float[]} representation and verifies the mean. */
    public void testTrainerUsingFloatArray() {
        GaussianModel gm = newModel();
        for (int i = 0; i < NUM_VECTORS; i++) {
            gm.update(floatArrays[i]);
        }
        estimateAndCheckMean(gm);
    }

    /**
     * Creates the model under test for {@link #numElements} dimensions.
     * NOTE(review): the {@code true} flag presumably enables training mode -
     * confirm against the GaussianModel constructor.
     */
    private GaussianModel newModel() {
        return new GaussianModel(numElements, true);
    }

    /**
     * Runs {@code estimate()} on a model that has already seen the whole data
     * set, prints it, and asserts that each estimated mean lies within
     * {@link #TOLERANCE_IN_STANDARD_ERRORS} standard errors of the true mean.
     *
     * @param gm
     *            model that has been updated with all {@link #NUM_VECTORS}
     *            samples
     */
    private void estimateAndCheckMean(GaussianModel gm) {
        gm.estimate();
        System.out.println(gm);

        double[] actualMean = gm.getMean();
        double sqrtSampleCount = Math.sqrt(NUM_VECTORS);
        for (int j = 0; j < numElements; j++) {
            /*
             * The sample mean of N draws has standard deviation std / sqrt(N),
             * so the band must scale with that, not with std itself. Rounding
             * from the float narrowing is orders of magnitude below this band.
             * Assumes the model accumulates in double precision - TODO confirm.
             */
            double tolerance = TOLERANCE_IN_STANDARD_ERRORS * std[j] / sqrtSampleCount;
            Assert.assertEquals("Assert mean.", mean[j], actualMean[j], tolerance);
        }
    }
}