blob: 468eab60aabd9ef8fee9edb2904ef998aea5bda7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with this
* work for additional information regarding copyright ownership. The ASF
* licenses this file to You under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law
* or agreed to in writing, software distributed under the License is
* distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the specific language
* governing permissions and limitations under the License.
*/
package org.apache.commons.math3.filter;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.MatrixDimensionMismatchException;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
import org.junit.Assert;
import org.junit.Test;
/**
* Tests for {@link KalmanFilter}.
*
*/
public class KalmanFilterTest {
@Test(expected=MatrixDimensionMismatchException.class)
public void testTransitionMeasurementMatrixMismatch() {
// A and H matrix do not match in dimensions
// A = [ 1 ]
RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
// no control input
RealMatrix B = null;
// H = [ 1 1 ]
RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d, 1d });
// Q = [ 0 ]
RealMatrix Q = new Array2DRowRealMatrix(new double[] { 0 });
// R = [ 0 ]
RealMatrix R = new Array2DRowRealMatrix(new double[] { 0 });
ProcessModel pm
= new DefaultProcessModel(A, B, Q,
new ArrayRealVector(new double[] { 0 }), null);
MeasurementModel mm = new DefaultMeasurementModel(H, R);
new KalmanFilter(pm, mm);
Assert.fail("transition and measurement matrix should not be compatible");
}
@Test(expected=MatrixDimensionMismatchException.class)
public void testTransitionControlMatrixMismatch() {
// A and B matrix do not match in dimensions
// A = [ 1 ]
RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
// B = [ 1 1 ]
RealMatrix B = new Array2DRowRealMatrix(new double[] { 1d, 1d });
// H = [ 1 ]
RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d });
// Q = [ 0 ]
RealMatrix Q = new Array2DRowRealMatrix(new double[] { 0 });
// R = [ 0 ]
RealMatrix R = new Array2DRowRealMatrix(new double[] { 0 });
ProcessModel pm
= new DefaultProcessModel(A, B, Q,
new ArrayRealVector(new double[] { 0 }), null);
MeasurementModel mm = new DefaultMeasurementModel(H, R);
new KalmanFilter(pm, mm);
Assert.fail("transition and control matrix should not be compatible");
}
@Test
public void testConstant() {
// simulates a simple process with a constant state and no control input
double constantValue = 10d;
double measurementNoise = 0.1d;
double processNoise = 1e-5d;
// A = [ 1 ]
RealMatrix A = new Array2DRowRealMatrix(new double[] { 1d });
// no control input
RealMatrix B = null;
// H = [ 1 ]
RealMatrix H = new Array2DRowRealMatrix(new double[] { 1d });
// x = [ 10 ]
RealVector x = new ArrayRealVector(new double[] { constantValue });
// Q = [ 1e-5 ]
RealMatrix Q = new Array2DRowRealMatrix(new double[] { processNoise });
// R = [ 0.1 ]
RealMatrix R = new Array2DRowRealMatrix(new double[] { measurementNoise });
ProcessModel pm
= new DefaultProcessModel(A, B, Q,
new ArrayRealVector(new double[] { constantValue }), null);
MeasurementModel mm = new DefaultMeasurementModel(H, R);
KalmanFilter filter = new KalmanFilter(pm, mm);
Assert.assertEquals(1, filter.getMeasurementDimension());
Assert.assertEquals(1, filter.getStateDimension());
assertMatrixEquals(Q.getData(), filter.getErrorCovariance());
// check the initial state
double[] expectedInitialState = new double[] { constantValue };
assertVectorEquals(expectedInitialState, filter.getStateEstimation());
RealVector pNoise = new ArrayRealVector(1);
RealVector mNoise = new ArrayRealVector(1);
RandomGenerator rand = new JDKRandomGenerator();
// iterate 60 steps
for (int i = 0; i < 60; i++) {
filter.predict();
// Simulate the process
pNoise.setEntry(0, processNoise * rand.nextGaussian());
// x = A * x + p_noise
x = A.operate(x).add(pNoise);
// Simulate the measurement
mNoise.setEntry(0, measurementNoise * rand.nextGaussian());
// z = H * x + m_noise
RealVector z = H.operate(x).add(mNoise);
filter.correct(z);
// state estimate shouldn't be larger than measurement noise
double diff = FastMath.abs(constantValue - filter.getStateEstimation()[0]);
// System.out.println(diff);
Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
}
// error covariance should be already very low (< 0.02)
Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[0][0],
0.02d, 1e-6) < 0);
}
@Test
public void testConstantAcceleration() {
// simulates a vehicle, accelerating at a constant rate (0.1 m/s)
// discrete time interval
double dt = 0.1d;
// position measurement noise (meter)
double measurementNoise = 10d;
// acceleration noise (meter/sec^2)
double accelNoise = 0.2d;
// A = [ 1 dt ]
// [ 0 1 ]
RealMatrix A = new Array2DRowRealMatrix(new double[][] { { 1, dt }, { 0, 1 } });
// B = [ dt^2/2 ]
// [ dt ]
RealMatrix B = new Array2DRowRealMatrix(
new double[][] { { FastMath.pow(dt, 2d) / 2d }, { dt } });
// H = [ 1 0 ]
RealMatrix H = new Array2DRowRealMatrix(new double[][] { { 1d, 0d } });
// x = [ 0 0 ]
RealVector x = new ArrayRealVector(new double[] { 0, 0 });
RealMatrix tmp = new Array2DRowRealMatrix(
new double[][] { { FastMath.pow(dt, 4d) / 4d, FastMath.pow(dt, 3d) / 2d },
{ FastMath.pow(dt, 3d) / 2d, FastMath.pow(dt, 2d) } });
// Q = [ dt^4/4 dt^3/2 ]
// [ dt^3/2 dt^2 ]
RealMatrix Q = tmp.scalarMultiply(FastMath.pow(accelNoise, 2));
// P0 = [ 1 1 ]
// [ 1 1 ]
RealMatrix P0 = new Array2DRowRealMatrix(new double[][] { { 1, 1 }, { 1, 1 } });
// R = [ measurementNoise^2 ]
RealMatrix R = new Array2DRowRealMatrix(
new double[] { FastMath.pow(measurementNoise, 2) });
// constant control input, increase velocity by 0.1 m/s per cycle
RealVector u = new ArrayRealVector(new double[] { 0.1d });
ProcessModel pm = new DefaultProcessModel(A, B, Q, x, P0);
MeasurementModel mm = new DefaultMeasurementModel(H, R);
KalmanFilter filter = new KalmanFilter(pm, mm);
Assert.assertEquals(1, filter.getMeasurementDimension());
Assert.assertEquals(2, filter.getStateDimension());
assertMatrixEquals(P0.getData(), filter.getErrorCovariance());
// check the initial state
double[] expectedInitialState = new double[] { 0.0, 0.0 };
assertVectorEquals(expectedInitialState, filter.getStateEstimation());
RandomGenerator rand = new JDKRandomGenerator();
RealVector tmpPNoise = new ArrayRealVector(
new double[] { FastMath.pow(dt, 2d) / 2d, dt });
// iterate 60 steps
for (int i = 0; i < 60; i++) {
filter.predict(u);
// Simulate the process
RealVector pNoise = tmpPNoise.mapMultiply(accelNoise * rand.nextGaussian());
// x = A * x + B * u + pNoise
x = A.operate(x).add(B.operate(u)).add(pNoise);
// Simulate the measurement
double mNoise = measurementNoise * rand.nextGaussian();
// z = H * x + m_noise
RealVector z = H.operate(x).mapAdd(mNoise);
filter.correct(z);
// state estimate shouldn't be larger than the measurement noise
double diff = FastMath.abs(x.getEntry(0) - filter.getStateEstimation()[0]);
Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
}
// error covariance of the velocity should be already very low (< 0.1)
Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[1][1],
0.1d, 1e-6) < 0);
}
/**
* Represents an idealized Cannonball only taking into account gravity.
*/
public static class Cannonball {
private final double[] gravity = { 0, -9.81 };
private final double[] velocity;
private final double[] location;
private double timeslice;
public Cannonball(double timeslice, double angle, double initialVelocity) {
this.timeslice = timeslice;
final double angleInRadians = FastMath.toRadians(angle);
this.velocity = new double[] {
initialVelocity * FastMath.cos(angleInRadians),
initialVelocity * FastMath.sin(angleInRadians)
};
this.location = new double[] { 0, 0 };
}
public double getX() {
return location[0];
}
public double getY() {
return location[1];
}
public double getXVelocity() {
return velocity[0];
}
public double getYVelocity() {
return velocity[1];
}
public void step() {
// break gravitational force into a smaller time slice.
double[] slicedGravity = gravity.clone();
for ( int i = 0; i < slicedGravity.length; i++ ) {
slicedGravity[i] *= timeslice;
}
// apply the acceleration to velocity.
double[] slicedVelocity = velocity.clone();
for ( int i = 0; i < velocity.length; i++ ) {
velocity[i] += slicedGravity[i];
slicedVelocity[i] = velocity[i] * timeslice;
location[i] += slicedVelocity[i];
}
// cannonballs shouldn't go into the ground.
if ( location[1] < 0 ) {
location[1] = 0;
}
}
}
@Test
public void testCannonball() {
// simulates the flight of a cannonball (only taking gravity and initial thrust into account)
// number of iterations
final int iterations = 144;
// discrete time interval
final double dt = 0.1d;
// position measurement noise (meter)
final double measurementNoise = 30d;
// the initial velocity of the cannonball
final double initialVelocity = 100;
// shooting angle
final double angle = 45;
final Cannonball cannonball = new Cannonball(dt, angle, initialVelocity);
final double speedX = cannonball.getXVelocity();
final double speedY = cannonball.getYVelocity();
// A = [ 1, dt, 0, 0 ] => x(n+1) = x(n) + vx(n)
// [ 0, 1, 0, 0 ] => vx(n+1) = vx(n)
// [ 0, 0, 1, dt ] => y(n+1) = y(n) + vy(n)
// [ 0, 0, 0, 1 ] => vy(n+1) = vy(n)
final RealMatrix A = MatrixUtils.createRealMatrix(new double[][] {
{ 1, dt, 0, 0 },
{ 0, 1, 0, 0 },
{ 0, 0, 1, dt },
{ 0, 0, 0, 1 }
});
// The control vector, which adds acceleration to the kinematic equations.
// 0 => x(n+1) = x(n+1)
// 0 => vx(n+1) = vx(n+1)
// -9.81*dt^2 => y(n+1) = y(n+1) - 1/2 * 9.81 * dt^2
// -9.81*dt => vy(n+1) = vy(n+1) - 9.81 * dt
final RealVector controlVector =
MatrixUtils.createRealVector(new double[] { 0, 0, 0.5 * -9.81 * dt * dt, -9.81 * dt } );
// The control matrix B only expects y and vy, see control vector
final RealMatrix B = MatrixUtils.createRealMatrix(new double[][] {
{ 0, 0, 0, 0 },
{ 0, 0, 0, 0 },
{ 0, 0, 1, 0 },
{ 0, 0, 0, 1 }
});
// We only observe the x/y position of the cannonball
final RealMatrix H = MatrixUtils.createRealMatrix(new double[][] {
{ 1, 0, 0, 0 },
{ 0, 0, 0, 0 },
{ 0, 0, 1, 0 },
{ 0, 0, 0, 0 }
});
// our guess of the initial state.
final RealVector initialState = MatrixUtils.createRealVector(new double[] { 0, speedX, 0, speedY } );
// the initial error covariance matrix, the variance = noise^2
final double var = measurementNoise * measurementNoise;
final RealMatrix initialErrorCovariance = MatrixUtils.createRealMatrix(new double[][] {
{ var, 0, 0, 0 },
{ 0, 1e-3, 0, 0 },
{ 0, 0, var, 0 },
{ 0, 0, 0, 1e-3 }
});
// we assume no process noise -> zero matrix
final RealMatrix Q = MatrixUtils.createRealMatrix(4, 4);
// the measurement covariance matrix
final RealMatrix R = MatrixUtils.createRealMatrix(new double[][] {
{ var, 0, 0, 0 },
{ 0, 1e-3, 0, 0 },
{ 0, 0, var, 0 },
{ 0, 0, 0, 1e-3 }
});
final ProcessModel pm = new DefaultProcessModel(A, B, Q, initialState, initialErrorCovariance);
final MeasurementModel mm = new DefaultMeasurementModel(H, R);
final KalmanFilter filter = new KalmanFilter(pm, mm);
final RandomGenerator rng = new Well19937c(1000);
final NormalDistribution dist = new NormalDistribution(rng, 0, measurementNoise);
for (int i = 0; i < iterations; i++) {
// get the "real" cannonball position
double x = cannonball.getX();
double y = cannonball.getY();
// apply measurement noise to current cannonball position
double nx = x + dist.sample();
double ny = y + dist.sample();
cannonball.step();
filter.predict(controlVector);
// correct the filter with our measurements
filter.correct(new double[] { nx, 0, ny, 0 } );
// state estimate shouldn't be larger than the measurement noise
double diff = FastMath.abs(cannonball.getY() - filter.getStateEstimation()[2]);
Assert.assertTrue(Precision.compareTo(diff, measurementNoise, 1e-6) < 0);
}
// error covariance of the x/y-position should be already very low (< 3m std dev = 9 variance)
Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[0][0],
9, 1e-6) < 0);
Assert.assertTrue(Precision.compareTo(filter.getErrorCovariance()[2][2],
9, 1e-6) < 0);
}
private void assertVectorEquals(double[] expected, double[] result) {
Assert.assertEquals("Wrong number of rows.", expected.length,
result.length);
for (int i = 0; i < expected.length; i++) {
Assert.assertEquals("Wrong value at position [" + i + "]",
expected[i], result[i], 1.0e-6);
}
}
private void assertMatrixEquals(double[][] expected, double[][] result) {
Assert.assertEquals("Wrong number of rows.", expected.length,
result.length);
for (int i = 0; i < expected.length; i++) {
Assert.assertEquals("Wrong number of columns.", expected[i].length,
result[i].length);
for (int j = 0; j < expected[i].length; j++) {
Assert.assertEquals("Wrong value at position [" + i + "," + j
+ "]", expected[i][j], result[i][j], 1.0e-6);
}
}
}
}