| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.udf.lib; |
| |
| import java.util.Arrays; |
| import java.util.Comparator; |
| |
| import org.apache.sysml.runtime.matrix.data.InputInfo; |
| import org.apache.sysml.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysml.runtime.matrix.data.OutputInfo; |
| import org.apache.sysml.udf.FunctionParameter; |
| import org.apache.sysml.udf.Matrix; |
| import org.apache.sysml.udf.PackageFunction; |
| import org.apache.sysml.udf.Scalar; |
| import org.apache.sysml.udf.Matrix.ValueType; |
| |
| /** |
| * Wrapper class for Sorting and Creating of a Permutation Matrix |
| * |
| * Sort single-column matrix and produce a permutation matrix. Pre-multiplying |
| * the input matrix with the permutation matrix produces a sorted matrix. A |
| * permutation matrix is a matrix where each row and each column as exactly one |
| * 1: To From 1 |
| * |
| * Input: (n x 1)-matrix, and true/false for sorting in descending order Output: |
| * (n x n)- matrix |
| * |
| * permutation_matrix= externalFunction(Matrix[Double] A, Boolean desc) return |
| * (Matrix[Double] P) implemented in |
| * (classname="org.apache.sysml.udf.lib.PermutationMatrixWrapper" |
| * ,exectype="mem"); A = read( "Data/A.mtx"); P = permutation_matrix( A[,2], |
| * false); B = P %*% A |
| * |
| */ |
| @Deprecated |
| public class PermutationMatrixWrapper extends PackageFunction |
| { |
| |
| private static final long serialVersionUID = 1L; |
| private static final String OUTPUT_FILE = "TMP"; |
| |
| // return matrix |
| private Matrix _ret; |
| |
| @Override |
| public int getNumFunctionOutputs() { |
| return 1; |
| } |
| |
| @Override |
| public FunctionParameter getFunctionOutput(int pos) { |
| if (pos == 0) |
| return _ret; |
| |
| throw new RuntimeException( |
| "Invalid function output being requested"); |
| } |
| |
| @Override |
| public void execute() { |
| try { |
| Matrix inM = (Matrix) getFunctionInput(0); |
| double[][] inData = inM.getMatrixAsDoubleArray(); |
| boolean desc = Boolean.parseBoolean(((Scalar) getFunctionInput(1)) |
| .getValue()); |
| |
| // add index column as first column |
| double[][] idxData = new double[(int) inM.getNumRows()][2]; |
| for (int i = 0; i < idxData.length; i++) { |
| idxData[i][0] = i; |
| idxData[i][1] = inData[i][0]; |
| } |
| |
| // sort input matrix (in-place) |
| if (!desc) // asc |
| Arrays.sort(idxData, new AscRowComparator(1)); |
| else |
| // desc |
| Arrays.sort(idxData, new DescRowComparator(1)); |
| |
| // create and populate sparse matrixblock for result |
| MatrixBlock mb = new MatrixBlock(idxData.length, idxData.length, |
| true, idxData.length); |
| for (int i = 0; i < idxData.length; i++) { |
| mb.quickSetValue(i, (int) idxData[i][0], 1.0); |
| } |
| mb.examSparsity(); |
| |
| // set result |
| String dir = createOutputFilePathAndName(OUTPUT_FILE); |
| _ret = new Matrix(dir, mb.getNumRows(), mb.getNumColumns(), |
| ValueType.Double); |
| _ret.setMatrixDoubleArray(mb, OutputInfo.BinaryBlockOutputInfo, |
| InputInfo.BinaryBlockInputInfo); |
| } |
| catch (Exception e) { |
| throw new RuntimeException( |
| "Error executing external permutation_matrix function", e); |
| } |
| } |
| |
| /** |
| * |
| * |
| */ |
| private static class AscRowComparator implements Comparator<double[]> { |
| private int _col = -1; |
| |
| public AscRowComparator(int col) { |
| _col = col; |
| } |
| |
| @Override |
| public int compare(double[] arg0, double[] arg1) { |
| return (arg0[_col] < arg1[_col] ? -1 |
| : (arg0[_col] == arg1[_col] ? 0 : 1)); |
| } |
| } |
| |
| /** |
| * |
| * |
| */ |
| private static class DescRowComparator implements Comparator<double[]> { |
| private int _col = -1; |
| |
| public DescRowComparator(int col) { |
| _col = col; |
| } |
| |
| @Override |
| public int compare(double[] arg0, double[] arg1) { |
| return (arg0[_col] > arg1[_col] ? -1 |
| : (arg0[_col] == arg1[_col] ? 0 : 1)); |
| } |
| } |
| } |