blob: b55265bc2f9d9c4f4bccaacd770a88e34ee64653 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math3.linear;
import java.util.Arrays;
import org.apache.commons.math3.util.FastMath;
import org.junit.Test;
import org.junit.Assert;
public class TriDiagonalTransformerTest {
private double[][] testSquare5 = {
{ 1, 2, 3, 1, 1 },
{ 2, 1, 1, 3, 1 },
{ 3, 1, 1, 1, 2 },
{ 1, 3, 1, 2, 1 },
{ 1, 1, 2, 1, 3 }
};
private double[][] testSquare3 = {
{ 1, 3, 4 },
{ 3, 2, 2 },
{ 4, 2, 0 }
};
@Test
public void testNonSquare() {
try {
new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
Assert.fail("an exception should have been thrown");
} catch (NonSquareMatrixException ime) {
// expected behavior
}
}
@Test
public void testAEqualQTQt() {
checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
}
private void checkAEqualQTQt(RealMatrix matrix) {
TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
RealMatrix q = transformer.getQ();
RealMatrix qT = transformer.getQT();
RealMatrix t = transformer.getT();
double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm();
Assert.assertEquals(0, norm, 4.0e-15);
}
@Test
public void testNoAccessBelowDiagonal() {
checkNoAccessBelowDiagonal(testSquare5);
checkNoAccessBelowDiagonal(testSquare3);
}
private void checkNoAccessBelowDiagonal(double[][] data) {
double[][] modifiedData = new double[data.length][];
for (int i = 0; i < data.length; ++i) {
modifiedData[i] = data[i].clone();
Arrays.fill(modifiedData[i], 0, i, Double.NaN);
}
RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
RealMatrix q = transformer.getQ();
RealMatrix qT = transformer.getQT();
RealMatrix t = transformer.getT();
double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm();
Assert.assertEquals(0, norm, 4.0e-15);
}
@Test
public void testQOrthogonal() {
checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
}
@Test
public void testQTOrthogonal() {
checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
}
private void checkOrthogonal(RealMatrix m) {
RealMatrix mTm = m.transpose().multiply(m);
RealMatrix id = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
Assert.assertEquals(0, mTm.subtract(id).getNorm(), 1.0e-15);
}
@Test
public void testTTriDiagonal() {
checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
}
private void checkTriDiagonal(RealMatrix m) {
final int rows = m.getRowDimension();
final int cols = m.getColumnDimension();
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
if ((i < j - 1) || (i > j + 1)) {
Assert.assertEquals(0, m.getEntry(i, j), 1.0e-16);
}
}
}
}
@Test
public void testMatricesValues5() {
checkMatricesValues(testSquare5,
new double[][] {
{ 1.0, 0.0, 0.0, 0.0, 0.0 },
{ 0.0, -0.5163977794943222, 0.016748280772542083, 0.839800693771262, 0.16669620021405473 },
{ 0.0, -0.7745966692414833, -0.4354553000860955, -0.44989322880603355, -0.08930153582895772 },
{ 0.0, -0.2581988897471611, 0.6364346693566014, -0.30263204032131164, 0.6608313651342882 },
{ 0.0, -0.2581988897471611, 0.6364346693566009, -0.027289660803112598, -0.7263191580755246 }
},
new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
new double[] { -FastMath.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
}
@Test
public void testMatricesValues3() {
checkMatricesValues(testSquare3,
new double[][] {
{ 1.0, 0.0, 0.0 },
{ 0.0, -0.6, 0.8 },
{ 0.0, -0.8, -0.6 },
},
new double[] { 1, 2.64, -0.64 },
new double[] { -5, -1.52 });
}
private void checkMatricesValues(double[][] matrix, double[][] qRef,
double[] mainDiagnonal,
double[] secondaryDiagonal) {
TriDiagonalTransformer transformer =
new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
// check values against known references
RealMatrix q = transformer.getQ();
Assert.assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm(), 1.0e-14);
RealMatrix t = transformer.getT();
double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
for (int i = 0; i < mainDiagnonal.length; ++i) {
tData[i][i] = mainDiagnonal[i];
if (i > 0) {
tData[i][i - 1] = secondaryDiagonal[i - 1];
}
if (i < secondaryDiagonal.length) {
tData[i][i + 1] = secondaryDiagonal[i];
}
}
Assert.assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm(), 1.0e-14);
// check the same cached instance is returned the second time
Assert.assertTrue(q == transformer.getQ());
Assert.assertTrue(t == transformer.getT());
}
}