| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.instructions.spark; |
| |
| import java.util.ArrayList; |
| import java.util.Iterator; |
| |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.spark.api.java.JavaPairRDD; |
| import org.apache.spark.api.java.function.Function; |
| import org.apache.spark.api.java.function.PairFlatMapFunction; |
| import org.apache.spark.api.java.function.PairFunction; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; |
| import org.apache.sysds.runtime.functionobjects.DiagIndex; |
| import org.apache.sysds.runtime.functionobjects.RevIndex; |
| import org.apache.sysds.runtime.functionobjects.SortIndex; |
| import org.apache.sysds.runtime.functionobjects.SwapIndex; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; |
| import org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction; |
| import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInList; |
| import org.apache.sysds.runtime.instructions.spark.functions.IsBlockInRange; |
| import org.apache.sysds.runtime.instructions.spark.functions.ReorgMapFunction; |
| import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils; |
| import org.apache.sysds.runtime.instructions.spark.utils.RDDSortUtils; |
| import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; |
| import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.data.MatrixIndexes; |
| import org.apache.sysds.runtime.matrix.operators.Operator; |
| import org.apache.sysds.runtime.matrix.operators.ReorgOperator; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import org.apache.sysds.runtime.util.DataConverter; |
| import org.apache.sysds.runtime.util.IndexRange; |
| import org.apache.sysds.runtime.util.UtilFunctions; |
| |
| import scala.Tuple2; |
| |
| public class ReorgSPInstruction extends UnarySPInstruction { |
| private static final Log LOG = LogFactory.getLog(ReorgSPInstruction.class.getName()); |
| |
| // sort-specific attributes (to enable variable attributes) |
| private CPOperand _col = null; |
| private CPOperand _desc = null; |
| private CPOperand _ixret = null; |
| private boolean _bSortIndInMem = false; |
| |
| private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) { |
| super(SPType.Reorg, op, in, out, opcode, istr); |
| } |
| |
| private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand desc, CPOperand ixret, CPOperand out, |
| String opcode, boolean bSortIndInMem, String istr) { |
| this(op, in, out, opcode, istr); |
| _col = col; |
| _desc = desc; |
| _ixret = ixret; |
| _bSortIndInMem = bSortIndInMem; |
| } |
| |
| public static ReorgSPInstruction parseInstruction ( String str ) { |
| CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| String opcode = InstructionUtils.getOpCode(str); |
| |
| if ( opcode.equalsIgnoreCase("r'") ) { |
| parseUnaryInstruction(str, in, out); //max 2 operands |
| return new ReorgSPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, str); |
| } |
| else if ( opcode.equalsIgnoreCase("rev") ) { |
| parseUnaryInstruction(str, in, out); //max 2 operands |
| return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str); |
| } |
| else if ( opcode.equalsIgnoreCase("rdiag") ) { |
| parseUnaryInstruction(str, in, out); //max 2 operands |
| return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str); |
| } |
| else if ( opcode.equalsIgnoreCase("rsort") ) { |
| String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
| InstructionUtils.checkNumFields(parts, 5, 6); |
| in.split(parts[1]); |
| out.split(parts[5]); |
| CPOperand col = new CPOperand(parts[2]); |
| CPOperand desc = new CPOperand(parts[3]); |
| CPOperand ixret = new CPOperand(parts[4]); |
| boolean bSortIndInMem = false; |
| |
| if(parts.length > 5) |
| bSortIndInMem = Boolean.parseBoolean(parts[6]); |
| |
| return new ReorgSPInstruction(new ReorgOperator(new SortIndex(1,false,false)), |
| in, col, desc, ixret, out, opcode, bSortIndInMem, str); |
| } |
| else { |
| throw new DMLRuntimeException("Unknown opcode while parsing a ReorgInstruction: " + str); |
| } |
| } |
| |
| @Override |
| public void processInstruction(ExecutionContext ec) { |
| SparkExecutionContext sec = (SparkExecutionContext)ec; |
| String opcode = getOpcode(); |
| |
| //get input rdd handle |
| JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() ); |
| JavaPairRDD<MatrixIndexes,MatrixBlock> out = null; |
| DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName()); |
| |
| if( opcode.equalsIgnoreCase("r'") ) //TRANSPOSE |
| { |
| //execute transpose reorg operation |
| out = in1.mapToPair(new ReorgMapFunction(opcode)); |
| } |
| else if( opcode.equalsIgnoreCase("rev") ) //REVERSE |
| { |
| //execute reverse reorg operation |
| out = in1.flatMapToPair(new RDDRevFunction(mcIn)); |
| if( mcIn.getRows() % mcIn.getBlocksize() != 0 ) |
| out = RDDAggregateUtils.mergeByKey(out, false); |
| } |
| else if ( opcode.equalsIgnoreCase("rdiag") ) // DIAG |
| { |
| if(mcIn.getCols() == 1) { // diagV2M |
| out = in1.flatMapToPair(new RDDDiagV2MFunction(mcIn)); |
| } |
| else { // diagM2V |
| //execute diagM2V operation |
| out = in1.filter(new FilterDiagMatrixBlocksFunction()) |
| .mapToPair(new ReorgMapFunction(opcode)); |
| } |
| } |
| else if ( opcode.equalsIgnoreCase("rsort") ) //ORDER |
| { |
| // Sort by column 'col' in ascending/descending order and return either index/value |
| |
| //get parameters |
| long[] cols = _col.getDataType().isMatrix() ? DataConverter.convertToLongVector(ec.getMatrixInput(_col.getName())) : |
| new long[]{ec.getScalarInput(_col).getLongValue()}; |
| boolean desc = ec.getScalarInput(_desc).getBooleanValue(); |
| boolean ixret = ec.getScalarInput(_ixret).getBooleanValue(); |
| boolean singleCol = (mcIn.getCols() == 1); |
| out = in1; |
| |
| if( cols.length > mcIn.getBlocksize() ) |
| LOG.warn("Unsupported sort with number of order-by columns large than blocksize: "+cols.length); |
| |
| if( singleCol || cols.length==1 ) { |
| // extract column (if necessary) and sort |
| if( !singleCol ) |
| out = out.filter(new IsBlockInRange(1, mcIn.getRows(), cols[0], cols[0], mcIn)) |
| .mapValues(new ExtractColumn(UtilFunctions.computeCellInBlock(cols[0], mcIn.getBlocksize()))); |
| |
| //actual index/data sort operation |
| if( ixret ) //sort indexes |
| out = RDDSortUtils.sortIndexesByVal(out, !desc, mcIn.getRows(), mcIn.getBlocksize()); |
| else if( singleCol && !desc) //sort single-column matrix |
| out = RDDSortUtils.sortByVal(out, mcIn.getRows(), mcIn.getBlocksize()); |
| else if( !_bSortIndInMem ) //sort multi-column matrix w/ rewrite |
| out = RDDSortUtils.sortDataByVal(out, in1, !desc, mcIn.getRows(), mcIn.getCols(), mcIn.getBlocksize()); |
| else //sort multi-column matrix |
| out = RDDSortUtils.sortDataByValMemSort(out, in1, !desc, mcIn.getRows(), mcIn.getCols(), mcIn.getBlocksize(), sec, (ReorgOperator) _optr); |
| } |
| else { //general case: multi-column sort |
| // extract columns (if necessary) |
| if( cols.length < mcIn.getCols() ) |
| out = out.filter(new IsBlockInList(cols, mcIn)) |
| .mapToPair(new ExtractColumns(cols, mcIn)); |
| |
| // append extracted columns (if necessary) |
| if( mcIn.getCols() > mcIn.getBlocksize() ) |
| out = RDDAggregateUtils.mergeByKey(out); |
| |
| //actual index/data sort operation |
| if( ixret ) //sort indexes |
| out = RDDSortUtils.sortIndexesByVals(out, !desc, mcIn.getRows(), cols.length, mcIn.getBlocksize()); |
| else if( cols.length==mcIn.getCols() && !desc) //sort single-column matrix |
| out = RDDSortUtils.sortByVals(out, mcIn.getRows(), cols.length, mcIn.getBlocksize()); |
| else //sort multi-column matrix |
| out = RDDSortUtils.sortDataByVals(out, in1, !desc, mcIn.getRows(), |
| mcIn.getCols(), cols.length, mcIn.getBlocksize()); |
| } |
| } |
| else { |
| throw new DMLRuntimeException("Error: Incorrect opcode in ReorgSPInstruction:" + opcode); |
| } |
| |
| //store output rdd handle |
| if( opcode.equalsIgnoreCase("rsort") && _col.getDataType().isMatrix() ) |
| sec.releaseMatrixInput(_col.getName()); |
| updateReorgDataCharacteristics(sec); |
| sec.setRDDHandleForVariable(output.getName(), out); |
| sec.addLineageRDD(output.getName(), input1.getName()); |
| } |
| |
| private void updateReorgDataCharacteristics(SparkExecutionContext sec) { |
| DataCharacteristics mc1 = sec.getDataCharacteristics(input1.getName()); |
| DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName()); |
| |
| //infer initially unknown dimensions from inputs |
| if( !mcOut.dimsKnown() ) |
| { |
| if( !mc1.dimsKnown() ) |
| throw new DMLRuntimeException("Unable to compute output matrix characteristics from input."); |
| |
| if ( getOpcode().equalsIgnoreCase("r'") ) |
| mcOut.set(mc1.getCols(), mc1.getRows(), mc1.getBlocksize(), mc1.getBlocksize()); |
| else if ( getOpcode().equalsIgnoreCase("rdiag") ) |
| mcOut.set(mc1.getRows(), (mc1.getCols()>1)?1:mc1.getRows(), mc1.getBlocksize(), mc1.getBlocksize()); |
| else if ( getOpcode().equalsIgnoreCase("rsort") ) { |
| boolean ixret = sec.getScalarInput(_ixret).getBooleanValue(); |
| mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize()); |
| } |
| } |
| |
| //infer initially unknown nnz from input |
| if( !mcOut.nnzKnown() && mc1.nnzKnown() ){ |
| boolean sortIx = getOpcode().equalsIgnoreCase("rsort") && sec.getScalarInput(_ixret.getName(), _ixret.getValueType(), _ixret.isLiteral()).getBooleanValue(); |
| if( sortIx ) |
| mcOut.setNonZeros(mc1.getRows()); |
| else //default (r', rdiag, rsort data) |
| mcOut.setNonZeros(mc1.getNonZeros()); |
| } |
| } |
| |
| private static class RDDDiagV2MFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = 31065772250744103L; |
| |
| private ReorgOperator _reorgOp = null; |
| private DataCharacteristics _mcIn = null; |
| |
| public RDDDiagV2MFunction(DataCharacteristics mcIn) { |
| _reorgOp = new ReorgOperator(DiagIndex.getDiagIndexFnObject()); |
| _mcIn = mcIn; |
| } |
| |
| @Override |
| public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) { |
| ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(); |
| MatrixIndexes ixIn = arg0._1(); |
| MatrixBlock blkIn = arg0._2(); |
| |
| //compute output indexes and reorg data |
| long rix = ixIn.getRowIndex(); |
| MatrixIndexes ixOut = new MatrixIndexes(rix, rix); |
| MatrixBlock blkOut = blkIn.reorgOperations(_reorgOp, new MatrixBlock(), -1, -1, -1); |
| ret.add(new Tuple2<>(ixOut,blkOut)); |
| |
| // insert newly created empty blocks for entire row |
| int numBlocks = (int) Math.ceil((double)_mcIn.getRows()/_mcIn.getBlocksize()); |
| for(int i = 1; i <= numBlocks; i++) { |
| if(i != ixOut.getColumnIndex()) { |
| int lrlen = UtilFunctions.computeBlockSize(_mcIn.getRows(), rix, _mcIn.getBlocksize()); |
| int lclen = UtilFunctions.computeBlockSize(_mcIn.getRows(), i, _mcIn.getBlocksize()); |
| MatrixBlock emptyBlk = new MatrixBlock(lrlen, lclen, true); |
| ret.add(new Tuple2<>(new MatrixIndexes(rix, i), emptyBlk)); |
| } |
| } |
| |
| return ret.iterator(); |
| } |
| } |
| |
| private static class RDDRevFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = 1183373828539843938L; |
| |
| private DataCharacteristics _mcIn = null; |
| |
| public RDDRevFunction(DataCharacteristics mcIn) { |
| _mcIn = mcIn; |
| } |
| |
| @Override |
| public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) { |
| //construct input |
| IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0); |
| |
| //execute reverse operation |
| ArrayList<IndexedMatrixValue> out = new ArrayList<>(); |
| LibMatrixReorg.rev(in, _mcIn.getRows(), _mcIn.getBlocksize(), out); |
| |
| //construct output |
| return SparkUtils.fromIndexedMatrixBlock(out).iterator(); |
| } |
| } |
| |
| private static class ExtractColumn implements Function<MatrixBlock, MatrixBlock> |
| { |
| private static final long serialVersionUID = -1472164797288449559L; |
| |
| private int _col; |
| |
| public ExtractColumn(int col) { |
| _col = col; |
| } |
| |
| @Override |
| public MatrixBlock call(MatrixBlock arg0) |
| throws Exception |
| { |
| return arg0.slice(0, arg0.getNumRows()-1, _col, _col, new MatrixBlock()); |
| } |
| } |
| |
| private static class ExtractColumns implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>,MatrixIndexes,MatrixBlock> |
| { |
| private static final long serialVersionUID = 2902729186431711506L; |
| |
| private final long[] _cols; |
| private final int _blen; |
| |
| public ExtractColumns(long[] cols, DataCharacteristics mc) { |
| _cols = cols; |
| _blen = mc.getBlocksize(); |
| } |
| |
| @Override |
| public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) { |
| MatrixIndexes ix = arg0._1(); |
| MatrixBlock in = arg0._2(); |
| MatrixBlock out = new MatrixBlock(in.getNumRows(), _cols.length, true); |
| for(int i=0; i<_cols.length; i++) |
| if( UtilFunctions.isInBlockRange(ix, _blen, new IndexRange(1, Long.MAX_VALUE, _cols[i], _cols[i])) ) { |
| int index = UtilFunctions.computeCellInBlock(_cols[i], _blen); |
| out.leftIndexingOperations(in.slice(0, in.getNumRows()-1, index, index, new MatrixBlock()), |
| 0, in.getNumRows()-1, i, i, out, UpdateType.INPLACE); |
| } |
| return new Tuple2<>(new MatrixIndexes(ix.getRowIndex(), 1), out); |
| } |
| } |
| } |
| |