| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.s4.model; |
| |
| import java.util.Random; |
| |
| import org.ejml.data.DenseMatrix64F; |
| |
| import junit.framework.Assert; |
| import junit.framework.TestCase; |
| |
| public class TestGaussianModel extends TestCase { |
| |
| private int NUM_VECTORS = 100000; |
| private double mean[] = { 153, 10.0, 5.0, 0.1 }; |
| private double std[] = { 30, 2.0, 1.0, 5.5 }; |
| private int numElements = mean.length; |
| private DenseMatrix64F vectors[] = new DenseMatrix64F[NUM_VECTORS]; |
| private double doubleArrays[][] = new double[NUM_VECTORS][numElements]; |
| private float floatArrays[][] = new float[NUM_VECTORS][numElements]; |
| |
| private Random random = new Random(0); |
| |
| protected void setUp() { |
| |
| /* Generate the data set. */ |
| for (int i = 0; i < NUM_VECTORS; i++) { |
| vectors[i] = new DenseMatrix64F(numElements, 1); |
| for (int j = 0; j < numElements; j++) { |
| double v = mean[j] + std[j] * random.nextGaussian(); |
| vectors[i].set(j, v); |
| doubleArrays[i][j] = v; |
| floatArrays[i][j] = (float)v; |
| } |
| } |
| } |
| |
| public void testTrainerUsingEJML() { |
| GaussianModel gm = new GaussianModel(numElements, true); |
| for (int i = 0; i < NUM_VECTORS; i++) { |
| gm.update(vectors[i]); |
| } |
| gm.estimate(); |
| System.out.println(gm); |
| |
| double[] actualMean = gm.getMean(); |
| |
| for (int j = 0; j < mean.length; j++) { |
| |
| Assert.assertEquals("Assert mean.", mean[j], actualMean[j], std[j]); |
| } |
| } |
| |
| public void testTrainerUsingDoubleArray() { |
| GaussianModel gm = new GaussianModel(numElements, true); |
| for (int i = 0; i < NUM_VECTORS; i++) { |
| gm.update(doubleArrays[i]); |
| } |
| gm.estimate(); |
| System.out.println(gm); |
| |
| double[] actualMean = gm.getMean(); |
| |
| for (int j = 0; j < mean.length; j++) { |
| |
| Assert.assertEquals("Assert mean.", mean[j], actualMean[j], std[j]); |
| } |
| } |
| |
| public void testTrainerUsingFloatArray() { |
| GaussianModel gm = new GaussianModel(numElements, true); |
| for (int i = 0; i < NUM_VECTORS; i++) { |
| gm.update(floatArrays[i]); |
| } |
| gm.estimate(); |
| System.out.println(gm); |
| |
| double[] actualMean = gm.getMean(); |
| |
| for (int j = 0; j < mean.length; j++) { |
| |
| Assert.assertEquals("Assert mean.", mean[j], actualMean[j], std[j]); |
| } |
| } |
| } |