blob: d48b12f496478285d779af428849ecc68fcaff05 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.stat.regression;
import org.apache.commons.math4.legacy.linear.Array2DRowRealMatrix;
import org.apache.commons.math4.legacy.linear.LUDecomposition;
import org.apache.commons.math4.legacy.linear.RealMatrix;
import org.apache.commons.math4.legacy.linear.RealVector;
/**
* The GLS implementation of multiple linear regression.
*
* GLS assumes a general covariance matrix Omega of the error
* <pre>
* u ~ N(0, Omega)
* </pre>
*
* Estimated by GLS,
* <pre>
* b=(X' Omega^-1 X)^-1X'Omega^-1 y
* </pre>
* whose variance is
* <pre>
* Var(b)=(X' Omega^-1 X)^-1
* </pre>
* @since 2.0
*/
public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Covariance matrix. */
    private RealMatrix omega;

    /** Inverse of covariance matrix; {@code null} until first requested after omega was set. */
    private RealMatrix omegaInverse;

    /**
     * Replace sample data, overriding any previous sample.
     * <p>Both validation steps run before anything is stored, so that if
     * either of them throws, this method has not yet replaced the y values,
     * the x values or the covariance held by this instance.</p>
     *
     * @param y y values of the sample
     * @param x x values of the sample
     * @param covariance array representing the covariance matrix
     */
    public void newSampleData(double[] y, double[][] x, double[][] covariance) {
        // Validate everything first: storing x and y before checking the
        // covariance could leave new x/y paired with a stale omega.
        validateSampleData(x, y);
        validateCovarianceData(x, covariance);

        newYSampleData(y);
        newXSampleData(x);
        newCovarianceData(covariance);
    }

    /**
     * Add the covariance data.
     * <p>Any previously cached inverse is discarded.</p>
     *
     * @param omega the [n,n] array representing the covariance
     */
    protected void newCovarianceData(double[][] omega) {
        this.omega = new Array2DRowRealMatrix(omega);
        // Force getOmegaInverse() to recompute from the new matrix.
        this.omegaInverse = null;
    }

    /**
     * Get the inverse of the covariance.
     * <p>The inverse of the covariance matrix is lazily evaluated (through an
     * LU decomposition) and cached until the covariance is replaced.</p>
     *
     * @return inverse of the covariance
     */
    protected RealMatrix getOmegaInverse() {
        if (omegaInverse == null) {
            omegaInverse = new LUDecomposition(omega).getSolver().getInverse();
        }
        return omegaInverse;
    }

    /**
     * Calculates beta by GLS.
     * <pre>
     *  b=(X' Omega^-1 X)^-1X'Omega^-1 y
     * </pre>
     * <p>The right-hand side {@code X' Omega^-1 y} is evaluated from right to
     * left with matrix-vector products, and the system
     * {@code (X' Omega^-1 X) b = X' Omega^-1 y} is then solved with the LU
     * solver instead of multiplying by an explicitly formed inverse.</p>
     *
     * @return beta
     */
    @Override
    protected RealVector calculateBeta() {
        final RealMatrix oi = getOmegaInverse();
        final RealMatrix xt = getX().transpose();
        // X' (Omega^-1 y): two matrix-vector products, no intermediate matrices.
        final RealVector xtoiy = xt.operate(oi.operate(getY()));
        return new LUDecomposition(calculateXtOmegaInverseX()).getSolver().solve(xtoiy);
    }

    /**
     * Calculates the variance on the beta.
     * <pre>
     *  Var(b)=(X' Omega^-1 X)^-1
     * </pre>
     *
     * @return The beta variance matrix
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        return new LUDecomposition(calculateXtOmegaInverseX()).getSolver().getInverse();
    }

    /**
     * Calculates the estimated variance of the error term using the formula
     * <pre>
     *  Var(u) = Tr(u' Omega^-1 u)/(n-k)
     * </pre>
     * where n and k are the row and column dimensions of the design
     * matrix X.
     *
     * @return error variance
     * @since 2.2
     */
    @Override
    protected double calculateErrorVariance() {
        final RealVector residuals = calculateResiduals();
        // u' Omega^-1 u is a scalar, so its trace is the value itself.
        final double weightedSumOfSquares =
            residuals.dotProduct(getOmegaInverse().operate(residuals));
        final RealMatrix x = getX();
        final int degreesOfFreedom = x.getRowDimension() - x.getColumnDimension();
        return weightedSumOfSquares / degreesOfFreedom;
    }

    /**
     * Builds the matrix {@code X' Omega^-1 X} shared by the beta and
     * beta-variance computations.
     *
     * @return the product {@code X' Omega^-1 X}
     */
    private RealMatrix calculateXtOmegaInverseX() {
        final RealMatrix x = getX();
        return x.transpose().multiply(getOmegaInverse()).multiply(x);
    }
}