| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.commons.math4.legacy.stat.correlation; |
| |
| import java.util.Arrays; |
| import java.util.Comparator; |
| |
| import org.apache.commons.math4.legacy.exception.DimensionMismatchException; |
| import org.apache.commons.math4.legacy.linear.BlockRealMatrix; |
| import org.apache.commons.math4.legacy.linear.MatrixUtils; |
| import org.apache.commons.math4.legacy.linear.RealMatrix; |
| import org.apache.commons.math4.core.jdkmath.JdkMath; |
| import org.apache.commons.math4.legacy.core.Pair; |
| |
| /** |
| * Implementation of Kendall's Tau-b rank correlation. |
| * <p> |
| * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and |
| * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if |
| * x<sub>1</sub> < x<sub>2</sub> and y<sub>1</sub> < y<sub>2</sub> |
| * or x<sub>2</sub> < x<sub>1</sub> and y<sub>2</sub> < y<sub>1</sub>. |
| * The pair is <i>discordant</i> if x<sub>1</sub> < x<sub>2</sub> and |
| * y<sub>2</sub> < y<sub>1</sub> or x<sub>2</sub> < x<sub>1</sub> and |
| * y<sub>1</sub> < y<sub>2</sub>. If either x<sub>1</sub> = x<sub>2</sub> |
| * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor |
| * discordant. |
| * <p> |
| * Kendall's Tau-b is defined as: |
| * <div style="white-space: pre"><code> |
| * tau<sub>b</sub> = (n<sub>c</sub> - n<sub>d</sub>) / sqrt((n<sub>0</sub> - n<sub>1</sub>) * (n<sub>0</sub> - n<sub>2</sub>)) |
| * </code></div> |
| * <p> |
| * where: |
| * <ul> |
| * <li>n<sub>0</sub> = n * (n - 1) / 2</li> |
| * <li>n<sub>c</sub> = Number of concordant pairs</li> |
| * <li>n<sub>d</sub> = Number of discordant pairs</li> |
| * <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li> |
| * <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li> |
| * <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li> |
| * <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li> |
| * </ul> |
| * <p> |
| * This implementation uses the O(n log n) algorithm described in |
| * William R. Knight's 1966 paper "A Computer Method for Calculating |
| * Kendall's Tau with Ungrouped Data" in the Journal of the American |
| * Statistical Association. |
| * |
| * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient"> |
| * Kendall tau rank correlation coefficient (Wikipedia)</a> |
| * @see <a href="http://www.jstor.org/stable/2282833">A Computer |
| * Method for Calculating Kendall's Tau with Ungrouped Data</a> |
| * |
| * @since 3.3 |
| */ |
| public class KendallsCorrelation { |
| |
| /** correlation matrix. */ |
| private final RealMatrix correlationMatrix; |
| |
| /** |
| * Create a KendallsCorrelation instance without data. |
| */ |
| public KendallsCorrelation() { |
| correlationMatrix = null; |
| } |
| |
| /** |
| * Create a KendallsCorrelation from a rectangular array |
| * whose columns represent values of variables to be correlated. |
| * |
| * @param data rectangular array with columns representing variables |
| * @throws IllegalArgumentException if the input data array is not |
| * rectangular with at least two rows and two columns. |
| */ |
| public KendallsCorrelation(double[][] data) { |
| this(MatrixUtils.createRealMatrix(data)); |
| } |
| |
| /** |
| * Create a KendallsCorrelation from a RealMatrix whose columns |
| * represent variables to be correlated. |
| * |
| * @param matrix matrix with columns representing variables to correlate |
| */ |
| public KendallsCorrelation(RealMatrix matrix) { |
| correlationMatrix = computeCorrelationMatrix(matrix); |
| } |
| |
| /** |
| * Returns the correlation matrix. |
| * |
| * @return correlation matrix |
| */ |
| public RealMatrix getCorrelationMatrix() { |
| return correlationMatrix; |
| } |
| |
| /** |
| * Computes the Kendall's Tau rank correlation matrix for the columns of |
| * the input matrix. |
| * |
| * @param matrix matrix with columns representing variables to correlate |
| * @return correlation matrix |
| */ |
| public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) { |
| int nVars = matrix.getColumnDimension(); |
| RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars); |
| for (int i = 0; i < nVars; i++) { |
| for (int j = 0; j < i; j++) { |
| double corr = correlation(matrix.getColumn(i), matrix.getColumn(j)); |
| outMatrix.setEntry(i, j, corr); |
| outMatrix.setEntry(j, i, corr); |
| } |
| outMatrix.setEntry(i, i, 1d); |
| } |
| return outMatrix; |
| } |
| |
| /** |
| * Computes the Kendall's Tau rank correlation matrix for the columns of |
| * the input rectangular array. The columns of the array represent values |
| * of variables to be correlated. |
| * |
| * @param matrix matrix with columns representing variables to correlate |
| * @return correlation matrix |
| */ |
| public RealMatrix computeCorrelationMatrix(final double[][] matrix) { |
| return computeCorrelationMatrix(new BlockRealMatrix(matrix)); |
| } |
| |
| /** |
| * Computes the Kendall's Tau rank correlation coefficient between the two arrays. |
| * |
| * @param xArray first data array |
| * @param yArray second data array |
| * @return Returns Kendall's Tau rank correlation coefficient for the two arrays |
| * @throws DimensionMismatchException if the arrays lengths do not match |
| */ |
| public double correlation(final double[] xArray, final double[] yArray) |
| throws DimensionMismatchException { |
| |
| if (xArray.length != yArray.length) { |
| throw new DimensionMismatchException(xArray.length, yArray.length); |
| } |
| |
| final int n = xArray.length; |
| final long numPairs = sum(n - 1); |
| |
| @SuppressWarnings("unchecked") |
| Pair<Double, Double>[] pairs = new Pair[n]; |
| for (int i = 0; i < n; i++) { |
| pairs[i] = new Pair<>(xArray[i], yArray[i]); |
| } |
| |
| Arrays.sort(pairs, new Comparator<Pair<Double, Double>>() { |
| /** {@inheritDoc} */ |
| @Override |
| public int compare(Pair<Double, Double> pair1, Pair<Double, Double> pair2) { |
| int compareFirst = pair1.getFirst().compareTo(pair2.getFirst()); |
| return compareFirst != 0 ? compareFirst : pair1.getSecond().compareTo(pair2.getSecond()); |
| } |
| }); |
| |
| long tiedXPairs = 0; |
| long tiedXYPairs = 0; |
| long consecutiveXTies = 1; |
| long consecutiveXYTies = 1; |
| Pair<Double, Double> prev = pairs[0]; |
| for (int i = 1; i < n; i++) { |
| final Pair<Double, Double> curr = pairs[i]; |
| if (curr.getFirst().equals(prev.getFirst())) { |
| consecutiveXTies++; |
| if (curr.getSecond().equals(prev.getSecond())) { |
| consecutiveXYTies++; |
| } else { |
| tiedXYPairs += sum(consecutiveXYTies - 1); |
| consecutiveXYTies = 1; |
| } |
| } else { |
| tiedXPairs += sum(consecutiveXTies - 1); |
| consecutiveXTies = 1; |
| tiedXYPairs += sum(consecutiveXYTies - 1); |
| consecutiveXYTies = 1; |
| } |
| prev = curr; |
| } |
| tiedXPairs += sum(consecutiveXTies - 1); |
| tiedXYPairs += sum(consecutiveXYTies - 1); |
| |
| long swaps = 0; |
| @SuppressWarnings("unchecked") |
| Pair<Double, Double>[] pairsDestination = new Pair[n]; |
| for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) { |
| for (int offset = 0; offset < n; offset += 2 * segmentSize) { |
| int i = offset; |
| final int iEnd = JdkMath.min(i + segmentSize, n); |
| int j = iEnd; |
| final int jEnd = JdkMath.min(j + segmentSize, n); |
| |
| int copyLocation = offset; |
| while (i < iEnd || j < jEnd) { |
| if (i < iEnd) { |
| if (j < jEnd) { |
| if (pairs[i].getSecond().compareTo(pairs[j].getSecond()) <= 0) { |
| pairsDestination[copyLocation] = pairs[i]; |
| i++; |
| } else { |
| pairsDestination[copyLocation] = pairs[j]; |
| j++; |
| swaps += iEnd - i; |
| } |
| } else { |
| pairsDestination[copyLocation] = pairs[i]; |
| i++; |
| } |
| } else { |
| pairsDestination[copyLocation] = pairs[j]; |
| j++; |
| } |
| copyLocation++; |
| } |
| } |
| final Pair<Double, Double>[] pairsTemp = pairs; |
| pairs = pairsDestination; |
| pairsDestination = pairsTemp; |
| } |
| |
| long tiedYPairs = 0; |
| long consecutiveYTies = 1; |
| prev = pairs[0]; |
| for (int i = 1; i < n; i++) { |
| final Pair<Double, Double> curr = pairs[i]; |
| if (curr.getSecond().equals(prev.getSecond())) { |
| consecutiveYTies++; |
| } else { |
| tiedYPairs += sum(consecutiveYTies - 1); |
| consecutiveYTies = 1; |
| } |
| prev = curr; |
| } |
| tiedYPairs += sum(consecutiveYTies - 1); |
| |
| final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps; |
| final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs); |
| return concordantMinusDiscordant / JdkMath.sqrt(nonTiedPairsMultiplied); |
| } |
| |
| /** |
| * Returns the sum of the number from 1 .. n according to Gauss' summation formula: |
| * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \] |
| * |
| * @param n the summation end |
| * @return the sum of the number from 1 to n |
| */ |
| private static long sum(long n) { |
| return n * (n + 1) / 2L; |
| } |
| } |