// blob: 6b779bc5d04beaab50704f40255f37f3f1e3f306 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.spaceroots.mantissa.linalg;
import junit.framework.*;
/**
 * Unit tests for {@link GeneralSquareMatrix}: construction, element access,
 * copying, in-place addition/subtraction, determinant, linear solving and
 * inversion.
 *
 * <p>Most tests fill a matrix from an {@link ElementPattern} and then verify
 * the result against another pattern; the solver-related tests rely on three
 * small hand-built linear problems (two regular ones and a singular one).</p>
 */
public class GeneralSquareMatrixTest
  extends TestCase {

  /**
   * Value used throughout the tests, both as the threshold argument handed
   * to {@code getDeterminant}, {@code solve} and {@code getInverse} and as
   * the absolute tolerance when comparing floating point results.
   */
  private static final double EPSILON = 1.0e-10;

  /**
   * Build a named test case.
   * @param name name of the test case
   */
  public GeneralSquareMatrixTest(String name) {
    super(name);
  }

  /** A matrix built with order 3 must report 3 rows and 3 columns. */
  public void testDimensions() {
    GeneralSquareMatrix m = new GeneralSquareMatrix(3);
    assertTrue(m.getRows() == 3);
    assertTrue(m.getColumns() == 3);
  }

  /**
   * Both constructors must reject a non-positive order with an
   * {@link IllegalArgumentException}.
   */
  public void testInvalidDimensions() {

    try {
      new GeneralSquareMatrix(0);
      fail("an IllegalArgumentException should have been thrown for order 0");
    } catch (IllegalArgumentException e) {
      // expected
    }

    try {
      new GeneralSquareMatrix(-3, null);
      fail("an IllegalArgumentException should have been thrown for order -3");
    } catch (IllegalArgumentException e) {
      // expected
    }

  }

  /** Elements written with setElement must be read back by getElement. */
  public void testElements() {
    Matrix m = buildMatrix(5, new BilinearPattern(1.0, 0.01));
    checkMatrix(m, new BilinearPattern(1.0, 0.01));
  }

  /**
   * The copy constructor must produce an independent matrix: overwriting
   * every element of the source must leave the copy untouched.
   */
  public void testCopy() {
    GeneralSquareMatrix m1 = buildMatrix(5, new BilinearPattern(1.0, 0.01));
    GeneralSquareMatrix m2 = new GeneralSquareMatrix(m1);

    // scribble over the source after the copy has been taken
    overwrite(m1, -1.0);

    assertTrue(m2.getRows() == m1.getRows());
    assertTrue(m2.getColumns() == m1.getColumns());
    checkMatrix(m2, new BilinearPattern(1.0, 0.01));
  }

  /**
   * duplicate() must return a GeneralSquareMatrix that is independent of
   * the instance it was created from.
   */
  public void testDuplicate() {
    GeneralSquareMatrix m1 = buildMatrix(5, new BilinearPattern(1.0, 0.01));
    Matrix m2 = m1.duplicate();
    assertTrue(m2 instanceof GeneralSquareMatrix);

    // scribble over the source after the duplicate has been taken
    overwrite(m1, -1.0);

    assertTrue(m2.getRows() == m1.getRows());
    assertTrue(m2.getColumns() == m1.getColumns());
    checkMatrix(m2, new BilinearPattern(1.0, 0.01));
  }

  /** In-place addition: (1, 0.01) + (2, -0.03) must give (3, -0.02). */
  public void testSelfAdd() {
    GeneralSquareMatrix m1 = buildMatrix(5, new BilinearPattern(1.0, 0.01));
    GeneralSquareMatrix m2 = buildMatrix(5, new BilinearPattern(2, -0.03));
    m1.selfAdd(m2);
    checkMatrix(m1, new BilinearPattern(3, -0.02));
  }

  /** In-place subtraction: (1, 0.01) - (2, -0.03) must give (-1, 0.04). */
  public void testSelfSub() {
    GeneralSquareMatrix m1 = buildMatrix(5, new BilinearPattern(1.0, 0.01));
    GeneralSquareMatrix m2 = buildMatrix(5, new BilinearPattern(2, -0.03));
    m1.selfSub(m2);
    checkMatrix(m1, new BilinearPattern(-1, 0.04));
  }

  /**
   * Check the determinants of the three reference matrices:
   * 6 for problem 1, epsilon - 1 = -0.9999999 for problem 2
   * (epsilon = 1.0e-7) and 0 for the singular problem 3.
   */
  public void testDeterminant() {

    GeneralSquareMatrix m1 = buildProblem1().a;
    assertTrue(Math.abs(m1.getDeterminant(EPSILON) - 6.0) < EPSILON);

    GeneralSquareMatrix m2 = buildProblem2().a;
    assertTrue(Math.abs(m2.getDeterminant(EPSILON) + 0.9999999) < EPSILON);

    GeneralSquareMatrix m3 = buildProblem3().a;
    assertTrue(Math.abs(m3.getDeterminant(EPSILON) - 0.0) < EPSILON);

  }

  /**
   * Solve the two regular reference problems and check the solutions, then
   * make sure the singular problem is reported through a
   * {@link SingularMatrixException}.
   *
   * <p>Any other exception is deliberately left uncaught so that JUnit
   * reports it with its complete stack trace.</p>
   *
   * @exception SingularMatrixException if one of the regular problems is
   * wrongly considered singular
   */
  public void testSolve()
    throws SingularMatrixException {

    LinearProblem p;
    Matrix result;

    p = buildProblem1();
    result = p.a.solve(p.b, EPSILON);
    checkSolve(p, result);

    p = buildProblem2();
    result = p.a.solve(p.b, EPSILON);
    checkSolve(p, result);

    p = buildProblem3();
    try {
      result = p.a.solve(p.b, EPSILON);
      fail("got " + result + ", should have caught an exception");
    } catch (SingularMatrixException e) {
      // expected
    }

  }

  /**
   * Invert the two regular reference matrices and check that the product of
   * each matrix by its inverse is the identity, then make sure inverting the
   * singular matrix is reported through a {@link SingularMatrixException}.
   *
   * <p>Any other exception is deliberately left uncaught so that JUnit
   * reports it with its complete stack trace.</p>
   *
   * @exception SingularMatrixException if one of the regular matrices is
   * wrongly considered singular
   */
  public void testInverse()
    throws SingularMatrixException {

    SquareMatrix a, inverse;

    a = buildProblem1().a;
    inverse = a.getInverse(EPSILON);
    checkMatrix(a.mul(inverse), new IdentityPattern());

    a = buildProblem2().a;
    inverse = a.getInverse(EPSILON);
    checkMatrix(a.mul(inverse), new IdentityPattern());

    a = buildProblem3().a;
    try {
      inverse = a.getInverse(EPSILON);
      fail("got " + inverse + ", should have caught an exception");
    } catch (SingularMatrixException e) {
      // expected
    }

  }

  /**
   * Build the suite holding all the tests of this class.
   * @return test suite
   */
  public static Test suite() {
    return new TestSuite(GeneralSquareMatrixTest.class);
  }

  /** Rule giving the expected value of a matrix element from its indices. */
  private interface ElementPattern {
    /**
     * Get the value associated with an element.
     * @param i row index
     * @param j column index
     * @return value for element (i, j)
     */
    public double value(int i, int j);
  }

  /** Pattern whose elements are coeffI * i + coeffJ * j. */
  private static class BilinearPattern implements ElementPattern {

    /**
     * Simple constructor.
     * @param coeffI multiplier of the row index
     * @param coeffJ multiplier of the column index
     */
    public BilinearPattern(double coeffI, double coeffJ) {
      this.coeffI = coeffI;
      this.coeffJ = coeffJ;
    }

    /** {@inheritDoc} */
    public double value(int i, int j) {
      return coeffI * i + coeffJ * j;
    }

    /** Multiplier of the row index. */
    private final double coeffI;

    /** Multiplier of the column index. */
    private final double coeffJ;

  }

  /** Pattern of the identity matrix: 1 on the diagonal, 0 elsewhere. */
  private static class IdentityPattern implements ElementPattern {
    /** {@inheritDoc} */
    public double value(int i, int j) {
      return (i == j) ? 1.0 : 0.0;
    }
  }

  /**
   * Build a square matrix whose elements follow a pattern.
   * @param order order of the matrix to build
   * @param pattern rule giving the value of each element
   * @return a new matrix filled according to the pattern
   */
  public GeneralSquareMatrix buildMatrix(int order,
                                         ElementPattern pattern) {
    GeneralSquareMatrix m = new GeneralSquareMatrix(order);
    for (int i = 0; i < m.getRows(); ++i) {
      for (int j = 0; j < m.getColumns(); ++j) {
        m.setElement(i, j, pattern.value(i, j));
      }
    }
    return m;
  }

  /**
   * Assert that every element of a matrix matches a pattern, within an
   * absolute tolerance of {@link #EPSILON}.
   * @param m matrix to check
   * @param pattern rule giving the expected value of each element
   */
  public void checkMatrix(Matrix m, ElementPattern pattern) {
    for (int i = 0; i < m.getRows(); ++i) {
      for (int j = 0; j < m.getColumns(); ++j) {
        assertTrue(Math.abs(m.getElement(i, j) - pattern.value(i, j))
                   < EPSILON);
      }
    }
  }

  /**
   * Set every element of a matrix to the same value.
   * @param m matrix to overwrite
   * @param value value to store in all elements
   */
  private void overwrite(GeneralSquareMatrix m, double value) {
    for (int i = 0; i < m.getRows(); ++i) {
      for (int j = 0; j < m.getColumns(); ++j) {
        m.setElement(i, j, value);
      }
    }
  }

  /** Container for a linear problem a.x = b with its known solution. */
  private static class LinearProblem {

    /** Matrix of the system. */
    public final GeneralSquareMatrix a;

    /** Expected solution. */
    public final Matrix x;

    /** Right hand side. */
    public final Matrix b;

    /**
     * Simple constructor.
     * @param a matrix of the system
     * @param x expected solution
     * @param b right hand side
     */
    public LinearProblem(GeneralSquareMatrix a, Matrix x, Matrix b) {
      this.a = a;
      this.x = x;
      this.b = b;
    }

  }

  /**
   * Build a regular 4x4 problem with a single right hand side column.
   * The expected determinant used by {@link #testDeterminant()} is 6.
   * @return the first reference problem
   */
  private LinearProblem buildProblem1() {

    GeneralSquareMatrix a = new GeneralSquareMatrix(4);

    a.setElement(0, 0,   2.0);
    a.setElement(0, 1,   1.0);
    a.setElement(0, 2,   0.0);
    a.setElement(0, 3,   4.0);

    a.setElement(1, 0,  -4.0);
    a.setElement(1, 1,  -2.0);
    a.setElement(1, 2,   3.0);
    a.setElement(1, 3,  -7.0);

    a.setElement(2, 0,   4.0);
    a.setElement(2, 1,   1.0);
    a.setElement(2, 2,  -2.0);
    a.setElement(2, 3,   8.0);

    a.setElement(3, 0,   0.0);
    a.setElement(3, 1,  -3.0);
    a.setElement(3, 2, -12.0);
    a.setElement(3, 3,  -1.0);

    GeneralMatrix x = new GeneralMatrix(4, 1);
    x.setElement(0, 0,  3.0);
    x.setElement(1, 0,  4.0);
    x.setElement(2, 0, -1.0);
    x.setElement(3, 0, -2.0);

    GeneralMatrix b = new GeneralMatrix(4, 1);
    b.setElement(0, 0,  2.0);
    b.setElement(1, 0, -9.0);
    b.setElement(2, 0,  2.0);
    b.setElement(3, 0,  2.0);

    return new LinearProblem(a, x, b);

  }

  /**
   * Build a regular 2x2 problem with two right hand side columns.
   * The top-left element of the matrix is a tiny epsilon (1.0e-7) while all
   * other elements are 1, so the determinant is epsilon - 1.
   * @return the second reference problem
   */
  private LinearProblem buildProblem2() {

    double epsilon = 1.0e-7;

    GeneralSquareMatrix a = new GeneralSquareMatrix(2);
    a.setElement(0, 0, epsilon);
    a.setElement(0, 1, 1.0);
    a.setElement(1, 0, 1.0);
    a.setElement(1, 1, 1.0);

    GeneralMatrix x = new GeneralMatrix(2, 2);
    x.setElement(0, 0, 1.0 + epsilon);
    x.setElement(1, 0, 1.0 - epsilon);
    x.setElement(0, 1, epsilon);
    x.setElement(1, 1, 1.0);

    GeneralMatrix b = new GeneralMatrix(2, 2);
    b.setElement(0, 0, 1.0 + epsilon * epsilon);
    b.setElement(1, 0, 2.0);
    b.setElement(0, 1, 1.0 + epsilon * epsilon);
    b.setElement(1, 1, 1.0 + epsilon);

    return new LinearProblem(a, x, b);

  }

  /**
   * Build a singular 3x3 problem.
   *
   * <p>The matrix is
   * <pre>
   *    1   2  -3
   *    2   1   3
   *   -3   0  -9
   * </pre>
   * whose third column is 3 times the first one minus 3 times the second
   * one, hence its determinant is 0. The solution and right hand side are
   * left as freshly constructed 3x1 matrices since the problem is only
   * used to check singularity detection.</p>
   *
   * @return the third (singular) reference problem
   */
  private LinearProblem buildProblem3() {

    GeneralSquareMatrix a = new GeneralSquareMatrix(3);

    // the third coefficient of each row belongs to column 2
    // (it used to be written to column 1, clobbering the second coefficient)
    a.setElement(0, 0,  1.0);
    a.setElement(0, 1,  2.0);
    a.setElement(0, 2, -3.0);

    a.setElement(1, 0,  2.0);
    a.setElement(1, 1,  1.0);
    a.setElement(1, 2,  3.0);

    a.setElement(2, 0, -3.0);
    a.setElement(2, 1,  0.0);
    a.setElement(2, 2, -9.0);

    GeneralMatrix x = new GeneralMatrix(3, 1);
    GeneralMatrix b = new GeneralMatrix(3, 1);

    return new LinearProblem(a, x, b);

  }

  /**
   * Check a computed solution in two ways: the residual a.result - b must
   * vanish, and the result must match the expected solution, both within an
   * absolute tolerance of {@link #EPSILON}.
   * @param p problem that was solved
   * @param result solution computed by the solver
   */
  private void checkSolve(LinearProblem p, Matrix result) {

    Matrix residual = p.a.mul(result).sub(p.b);
    for (int i = 0; i < residual.getRows(); ++i) {
      for (int j = 0; j < residual.getColumns(); ++j) {
        assertTrue(Math.abs(residual.getElement(i, j)) < EPSILON);
      }
    }

    for (int i = 0; i < result.getRows(); ++i) {
      for (int j = 0; j < result.getColumns(); ++j) {
        assertTrue(Math.abs(result.getElement(i, j) - p.x.getElement(i, j))
                   < EPSILON);
      }
    }

  }

}