blob: 09a471591a26314c008e52cdd4eac8a4182f88c9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.statistics.inference;
/**
 * Provide support for square matrix basic algebraic operations.
 *
 * <p>Matrix element indexing is 0-based e.g. {@code get(0, 0)}
 * returns the element in the first row, first column of the matrix.
 *
 * <p>This class supports computations in the {@link KolmogorovSmirnovTest}.
 *
 * @since 1.1
 */
final class SquareMatrixSupport {
    /**
     * Define a real-valued square matrix.
     *
     * <p>This matrix supports a scale to protect against overflow. The true value
     * of any matrix value is multiplied by {@code 2^scale}. This is readily performed
     * using {@link Math#scalb(double, int)}.
     */
    interface RealSquareMatrix {
        /**
         * Gets the dimension for the rows and columns.
         *
         * @return the dimension
         */
        int dimension();

        /**
         * Gets the scale of the matrix values.
         * The true value is the value returned from {@link #get(int, int)} multiplied by
         * {@code 2^scale}.
         *
         * @return the scale
         */
        int scale();

        /**
         * Gets the value. This is a scaled value. The true value is the value returned
         * multiplied by {@code 2^scale}.
         *
         * @param i Row
         * @param j Column
         * @return the value
         * @see #scale
         */
        double get(int i, int j);

        /**
         * Returns {@code this} raised to the power {@code n}, i.e. the matrix product
         * of {@code n} copies of {@code this}. For {@code n = 0} the result is the
         * identity matrix.
         *
         * @param n raise {@code this} to power {@code n}
         * @return {@code this^n}
         * @throws IllegalArgumentException if {@code n < 0}
         */
        RealSquareMatrix power(int n);
    }

    /**
     * Implementation of {@link RealSquareMatrix} using a {@code double[]} array to
     * store entries. Values are addressed using {@code i*dim + j} where {@code dim} is
     * the square dimension.
     *
     * <p>Scaling is supported using the central element {@code [m][m]} where
     * {@code m = dimension/2}. Scaling is only implemented post-multiplication
     * to protect against overflow during repeat multiplication operations.
     *
     * <p>Note: The scaling is implemented to support computation of Kolmogorov's
     * distribution as described in:
     * <ul>
     * <li>
     * Marsaglia, G., Tsang, W. W., &amp; Wang, J. (2003).
     * <a href="https://doi.org/10.18637/jss.v008.i18">Evaluating Kolmogorov's Distribution.</a>
     * Journal of Statistical Software, 8(18), 1–4.
     * </ul>
     */
    private static class ArrayRealSquareMatrix implements RealSquareMatrix {
        /** The scaling threshold. Marsaglia used 1e140. This uses 2^400 ~ 2.58e120 */
        private static final double SCALE_THRESHOLD = 0x1.0p400;
        /** Dimension. */
        private final int dim;
        /** Entries of the matrix. */
        private final double[] data;
        /** Matrix scale. */
        private final int exp;

        /**
         * Create an instance. The data array is stored by reference (not copied).
         *
         * @param dimension Matrix dimension.
         * @param data Matrix data.
         * @param scale Matrix scale.
         */
        ArrayRealSquareMatrix(int dimension, double[] data, int scale) {
            this.dim = dimension;
            this.data = data;
            this.exp = scale;
        }

        @Override
        public int dimension() {
            return dim;
        }

        @Override
        public int scale() {
            return exp;
        }

        @Override
        public double get(int i, int j) {
            return data[i * dim + j];
        }

        /**
         * {@inheritDoc}
         *
         * <p>For {@code n = 1} the result is {@code this}. For {@code n > 1} a new
         * matrix is returned and the data of {@code this} is not modified.
         */
        @Override
        public RealSquareMatrix power(int n) {
            checkExponent(n);
            if (n == 0) {
                return identity();
            }
            if (n == 1) {
                return this;
            }

            // Here at least 1 multiplication occurs.
            // Compute the power by repeat squaring and multiplication:
            // 13 = 1101
            // x^13 = x^8 * x^4 * x^1
            //      = ((x^2 * x)^2)^2 * x
            // 21 = 10101
            // x^21 = x^16 * x^4 * x^1
            //      = (((x^2)^2 * x)^2)^2 * x
            // 1. Find highest set bit in n
            // 2. Initialise result as x
            // 3. For remaining bits (0 or 1) below the highest set bit:
            //    - square the current result
            //    - if the current bit is 1 then multiply by x
            // In this scheme we require 2 matrix array allocations and a column array.

            // Working arrays
            final double[] col = new double[dim];
            double[] b = new double[data.length];
            double[] tmp;

            // Initialise result as A^1.
            // The result is a copy so the recycled working arrays never overwrite 'data'.
            final double[] a = data;
            final int ea = exp;
            double[] r = a.clone();
            int er = ea;

            // Shift the highest set bit off the top.
            // Any remaining bits are detected in the sign bit.
            final int shift = Integer.numberOfLeadingZeros(n) + 1;
            int bits = n << shift;

            // Process remaining bits below highest set bit.
            for (int i = 32 - shift; i != 0; i--, bits <<= 1) {
                // Square the result
                er = multiply(r, er, r, er, col, b);
                // Recycle working array
                tmp = b;
                b = r;
                r = tmp;
                if (bits < 0) {
                    // Multiply by A
                    er = multiply(r, er, a, ea, col, b);
                    // Recycle working array
                    tmp = b;
                    b = r;
                    r = tmp;
                }
            }

            return new ArrayRealSquareMatrix(dim, r, er);
        }

        /**
         * Creates the identity matrix I with the same dimension as {@code this}.
         *
         * @return I
         */
        private RealSquareMatrix identity() {
            final int n = dimension();
            return new RealSquareMatrix() {
                @Override
                public int dimension() {
                    return n;
                }

                @Override
                public int scale() {
                    return 0;
                }

                @Override
                public double get(int i, int j) {
                    return i == j ? 1 : 0;
                }

                @Override
                public RealSquareMatrix power(int p) {
                    // Honour the interface contract: negative exponents are rejected.
                    checkExponent(p);
                    // I^p = I for any p >= 0
                    return this;
                }
            };
        }

        /**
         * Returns the result of postmultiplying {@code a} by {@code b}. It is expected
         * the scale of the result will be the sum of the scale of the arguments; this
         * may be adjusted by the scale power if the result is scaled by a power of two
         * for overflow protection.
         *
         * <p>The matrix dimension is taken from the length of {@code col}.
         *
         * @param a Matrix.
         * @param sa Scale of matrix a.
         * @param b Matrix to postmultiply by.
         * @param sb Scale of matrix b.
         * @param col Working array for a column of the matrix.
         * @param out Output {@code a * b}. Must not be the same array as {@code a} or {@code b}.
         * @return Scale of {@code a * b}
         */
        private static int multiply(double[] a, int sa, double[] b, int sb, double[] col, double[] out) {
            final int m = col.length;
            // Rows are contiguous; Columns are non-contiguous
            int k;
            for (int c = 0; c < m; c++) {
                // Extract column from b to contiguous memory
                k = c;
                for (int i = 0; i < m; i++, k += m) {
                    col[i] = b[k];
                }
                // row * col
                k = 0;
                for (int r = 0; r < m; r++) {
                    double sum = 0;
                    for (int i = 0; i < m; i++, k++) {
                        sum += a[k] * col[i];
                    }
                    out[r * m + c] = sum;
                }
            }

            int s = sa + sb;

            // An empty (0x0) matrix has no central element to inspect.
            if (m == 0) {
                return s;
            }

            // Overflow protection. Ideally we would check all elements but for speed
            // we check the central one only.
            k = m >> 1;
            if (out[k * m + k] > SCALE_THRESHOLD) {
                // Downscale
                // We could downscale by the inverse of SCALE_THRESHOLD.
                // However this does not account for how far above the threshold
                // the central element is. Here we downscale so the central element
                // is roughly 1 allowing other elements to be larger and still protected
                // from overflow.
                final int e = Math.getExponent(out[k * m + k]);
                final double downScale = Math.scalb(1.0, -e);
                s += e;
                for (int i = 0; i < out.length; i++) {
                    out[i] *= downScale;
                }
            }
            return s;
        }

        /**
         * Check the exponent is not negative.
         *
         * @param p Exponent.
         * @throws IllegalArgumentException if the exponent is negative
         */
        private static void checkExponent(int p) {
            if (p < 0) {
                throw new IllegalArgumentException("Not positive exponent: " + p);
            }
        }
    }

    /** No instances. */
    private SquareMatrixSupport() {}

    /**
     * Creates a square matrix. Data may be used in-place.
     *
     * <p>Values are addressed in row-major order: element {@code [i][j]} is stored
     * at {@code data[i*dimension + j]}.
     *
     * @param dimension Matrix dimension.
     * @param data Matrix data.
     * @return the square matrix
     * @throws IllegalArgumentException if the dimension is negative, or the matrix data
     * is not square (length != dimension * dimension)
     */
    static RealSquareMatrix create(int dimension, double[] data) {
        if (dimension < 0) {
            throw new IllegalArgumentException("Negative dimension: " + dimension);
        }
        // Use long arithmetic: dimension * dimension can overflow an int
        // (e.g. 65536^2 wraps to 0) and would otherwise accept mismatched data.
        if ((long) dimension * dimension != data.length) {
            // Note: %<d is 'relative indexing' to re-use the last argument
            throw new IllegalArgumentException(String.format("Not square: %d * %<d != %d", dimension, data.length));
        }
        return new ArrayRealSquareMatrix(dimension, data, 0);
    }
}