blob: c4ce0c6c82c848170183a4fb71890c0f55fc10cf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.udf.lib;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;
import org.apache.sysml.udf.Matrix.ValueType;
/**
 * Wrapper class for sorting a single-column matrix and creating the
 * corresponding permutation matrix.
 *
 * Sorts a single-column matrix and produces a permutation matrix.
 * Pre-multiplying the input matrix with the permutation matrix produces a
 * sorted matrix. A permutation matrix is a square matrix where each row and
 * each column has exactly one 1: entry P[to, from] = 1 moves input row "from"
 * to output row "to".
 *
 * Input: (n x 1)-matrix, and true/false for sorting in descending order.
 * Output: (n x n)-matrix.
 *
 * Example (DML):
 *
 * permutation_matrix = externalFunction(Matrix[Double] A, Boolean desc)
 *   return (Matrix[Double] P)
 *   implemented in (classname="org.apache.sysml.udf.lib.PermutationMatrixWrapper",
 *                   exectype="mem");
 * A = read("Data/A.mtx");
 * P = permutation_matrix(A[,2], false);
 * B = P %*% A
 */
@Deprecated
public class PermutationMatrixWrapper extends PackageFunction
{
	private static final long serialVersionUID = 1L;

	/** Base name passed to createOutputFilePathAndName for the result matrix. */
	private static final String OUTPUT_FILE = "TMP";

	/** The single function output (the n x n permutation matrix); null until execute() runs. */
	private Matrix _ret;

	/**
	 * Returns the number of outputs produced by this external function.
	 *
	 * @return always 1 (the permutation matrix)
	 */
	@Override
	public int getNumFunctionOutputs() {
		return 1;
	}

	/**
	 * Returns the requested function output.
	 *
	 * @param pos output position; only 0 is valid
	 * @return the permutation matrix set by execute(), or null if execute() has not run
	 * @throws RuntimeException if pos is not 0
	 */
	@Override
	public FunctionParameter getFunctionOutput(int pos) {
		if (pos == 0)
			return _ret;
		throw new RuntimeException(
			"Invalid function output being requested");
	}

	/**
	 * Sorts the first column of input 0 and builds the matching permutation matrix.
	 *
	 * Input 0 is the matrix to sort (only column 0 of each row is read). Input 1 is
	 * a scalar parsed with Boolean.parseBoolean; "true" selects descending order.
	 * Row i of the result holds a single 1.0 in the column of the input row that
	 * lands at position i after sorting, so P %*% A reorders the rows of A.
	 * Ties keep their original relative order (Arrays.sort on objects is stable).
	 * NaN values sort after all other values in ascending order and before them
	 * in descending order.
	 *
	 * @throws RuntimeException wrapping any failure while reading, sorting or writing
	 */
	@Override
	public void execute() {
		try {
			Matrix inM = (Matrix) getFunctionInput(0);
			double[][] inData = inM.getMatrixAsDoubleArray();
			boolean desc = Boolean.parseBoolean(((Scalar) getFunctionInput(1))
				.getValue());

			// pair each value with its original row index: [origIndex, value]
			int n = (int) inM.getNumRows();
			double[][] idxData = new double[n][2];
			for (int i = 0; i < n; i++) {
				idxData[i][0] = i;
				idxData[i][1] = inData[i][0];
			}

			// sort (in-place) by the value column; comparators use Double.compare
			// so the ordering stays a valid total order even if NaNs are present
			Comparator<double[]> cmp;
			if (desc)
				cmp = new DescRowComparator(1);
			else
				cmp = new AscRowComparator(1);
			Arrays.sort(idxData, cmp);

			// create and populate sparse matrixblock: exactly one 1.0 per row,
			// placed at (target position, original row index)
			MatrixBlock mb = new MatrixBlock(n, n, true, n);
			for (int i = 0; i < n; i++)
				mb.quickSetValue(i, (int) idxData[i][0], 1.0);
			mb.examSparsity();

			// set result
			String dir = createOutputFilePathAndName(OUTPUT_FILE);
			_ret = new Matrix(dir, mb.getNumRows(), mb.getNumColumns(),
				ValueType.Double);
			_ret.setMatrixDoubleArray(mb, OutputInfo.BinaryBlockOutputInfo,
				InputInfo.BinaryBlockInputInfo);
		}
		catch (Exception e) {
			throw new RuntimeException(
				"Error executing external permutation_matrix function", e);
		}
	}

	/**
	 * Orders rows ascending by the value stored in a fixed column.
	 *
	 * Delegates to Double.compare so the Comparator contract holds for NaN
	 * (NaN sorts last); Double.compare also orders -0.0 before 0.0.
	 */
	private static class AscRowComparator implements Comparator<double[]> {
		private final int _col;

		public AscRowComparator(int col) {
			_col = col;
		}

		@Override
		public int compare(double[] arg0, double[] arg1) {
			return Double.compare(arg0[_col], arg1[_col]);
		}
	}

	/**
	 * Orders rows descending by the value stored in a fixed column.
	 *
	 * Exact reverse of AscRowComparator (arguments swapped into Double.compare),
	 * so NaN sorts first and the Comparator contract holds.
	 */
	private static class DescRowComparator implements Comparator<double[]> {
		private final int _col;

		public DescRowComparator(int col) {
			_col = col;
		}

		@Override
		public int compare(double[] arg0, double[] arg1) {
			return Double.compare(arg1[_col], arg0[_col]);
		}
	}
}