| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # Computes the matrix square root B of a matrix A, such that |
| # A = B %*% B. |
| # |
| # INPUT: |
| # ------------------------------------------------------------------------------ |
| # A Input Matrix A |
| # S Strategy: "COMMON" (Java-based, commons-math) or "DML" (script-based) |
| # ------------------------------------------------------------------------------ |
| # |
| # OUTPUT: |
| # ------------------------------------------------------------------------------ |
| # B Output Matrix B |
| # ------------------------------------------------------------------------------ |
| |
| |
| m_sqrtMatrix = function(Matrix[Double] A, String S) |
|   return(Matrix[Double] B) |
| { |
|   # Dispatch on the strategy S: "COMMON" delegates to sqrtMatrixJava, "DML" |
|   # computes the root in script, and any other value stops with an error. |
|   if (S == "COMMON") { |
|     B = sqrtMatrixJava(A) |
|   } else if (S == "DML") { |
|     N = nrow(A); |
|     D = ncol(A); |
| |
|     # check that the matrix is square |
|     if (D != N) { |
|       stop("matrixSqrt Input Error: matrix not square!") |
|     } |
| |
|     isDiag = isDiagonal(A) |
|     if (isDiag) { |
|       # The eigenvalues of a diagonal matrix are its diagonal entries, so a |
|       # negative entry means there is no real square root. Reject it here with |
|       # the same error as the eigenvalue check below instead of returning NaN. |
|       if (min(diag(A)) < 0) { |
|         stop("matrixSqrt exec Error: matrix has imaginary square root"); |
|       } |
|       B = sqrtDiagMatrix(A); |
|     } else { |
|       [eValues, eVectors] = eigen(A); |
| |
|       # a negative eigenvalue would require an imaginary root |
|       hasNonNegativeEigenValues = (sum(eValues >= 0) == length(eValues)); |
| |
|       if (!hasNonNegativeEigenValues) { |
|         stop("matrixSqrt exec Error: matrix has imaginary square root"); |
|       } |
| |
|       # exact cell-wise comparison of A with its transpose |
|       isSymmetric = sum(A == t(A)) == length(A); |
|       allEigenValuesUnique = length(eValues) == length(unique(eValues)); |
| |
|       if (allEigenValuesUnique | isSymmetric) { |
|         # calculate X = VDV^(-1) -> S = sqrt(D) -> sqrt_x = VSV^(-1) |
|         sqrtD = sqrtDiagMatrix(diag(eValues)); |
|         V_Inv = inv(eVectors); |
|         B = eVectors %*% sqrtD %*% V_Inv; |
|       } else { |
|         # formula: Denman-Beavers iteration, starting from Y = A and Z = I |
|         Y = A |
|         # identity matrix |
|         Z = diag(matrix(1.0, rows=N, cols=1)) |
| |
|         # Fixed 100 iterations, no convergence test. Both updates are computed |
|         # from the previous Y and Z before either variable is overwritten. |
|         for (x in 1:100) { |
|           Y_new = (1 / 2) * (Y + inv(Z)) |
|           Z_new = (1 / 2) * (Z + inv(Y)) |
|           Y = Y_new |
|           Z = Z_new |
|         } |
|         B = Y |
|       } |
|     } |
|   } else { |
|     stop("Error: Unknown strategy for matrix square root.") |
|   } |
| } |
| |
| # Square root of a diagonal matrix, taken entry-wise on its diagonal. |
| # Assumes X is square and diagonal: only the diagonal of X is read, and the |
| # result is rebuilt as a diagonal matrix. Negative entries are not checked here. |
| sqrtDiagMatrix = function(Matrix[Double] X) |
|   return(Matrix[Double] sqrt_x) |
| { |
|   # sqrt(1) = 1, so an identity input maps to itself without a separate check. |
|   sqrt_x = diag(sqrt(diag(X))); |
| } |
| |
| isDiagonal = function (Matrix[Double] X) |
|   return(boolean diagonal) |
| { |
|   # Rebuild X from its diagonal alone; X is diagonal exactly when every cell |
|   # of that rebuilt matrix equals the corresponding cell of X. |
|   diagOnly = diag(diag(X)); |
|   matchingCells = sum(diagOnly == X); |
|   diagonal = (matchingCells == length(X)); |
| } |
| |