| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.instructions.spark; |
| |
| import org.apache.spark.api.java.JavaPairRDD; |
| import org.apache.spark.api.java.function.Function2; |
| import org.apache.spark.api.java.function.PairFlatMapFunction; |
| import org.apache.spark.api.java.function.PairFunction; |
| import org.apache.sysds.common.Types.CorrectionLocationType; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.hops.OptimizerUtils; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.codegen.CodegenUtils; |
| import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; |
| import org.apache.sysds.runtime.codegen.SpoofCellwise; |
| import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp; |
| import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType; |
| import org.apache.sysds.runtime.codegen.SpoofMultiAggregate; |
| import org.apache.sysds.runtime.codegen.SpoofOperator; |
| import org.apache.sysds.runtime.codegen.SpoofOuterProduct; |
| import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType; |
| import org.apache.sysds.runtime.codegen.SpoofRowwise; |
| import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType; |
| import org.apache.sysds.runtime.controlprogram.caching.CacheableData; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; |
| import org.apache.sysds.runtime.functionobjects.Builtin; |
| import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; |
| import org.apache.sysds.runtime.functionobjects.KahanPlus; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.cp.DoubleObject; |
| import org.apache.sysds.runtime.instructions.cp.ScalarObject; |
| import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast; |
| import org.apache.sysds.runtime.instructions.spark.functions.MapInputSignature; |
| import org.apache.sysds.runtime.instructions.spark.functions.MapJoinSignature; |
| import org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction; |
| import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.data.MatrixIndexes; |
| import org.apache.sysds.runtime.matrix.operators.AggregateOperator; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import scala.Tuple2; |
| |
| import java.io.Serializable; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Iterator; |
| import java.util.LinkedList; |
| import java.util.List; |
| import java.util.stream.IntStream; |
| |
| public class SpoofSPInstruction extends SPInstruction { |
	//generated operator class and its raw bytecode (shipped to executors for lazy loading)
	private final Class<?> _class;
	private final byte[] _classBytes;
	//input operands (matrices and scalars) and single output operand
	private final CPOperand[] _in;
	private final CPOperand _out;

	/**
	 * Private constructor; instances are created via {@link #parseInstruction(String)}.
	 *
	 * @param cls        generated operator class
	 * @param classBytes bytecode of the generated class for executor-side loading
	 * @param in         input operands
	 * @param out        output operand
	 * @param opcode     instruction opcode (including spoof type suffix)
	 * @param str        original instruction string
	 */
	private SpoofSPInstruction(Class<?> cls, byte[] classBytes, CPOperand[] in, CPOperand out, String opcode,
		String str) {
		super(SPType.SpoofFused, opcode, str);
		_class = cls;
		_classBytes = classBytes;
		_in = in;
		_out = out;
	}
| |
| public static SpoofSPInstruction parseInstruction(String str) { |
| String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
| |
| //String opcode = parts[0]; |
| ArrayList<CPOperand> inlist = new ArrayList<>(); |
| Class<?> cls = CodegenUtils.getClass(parts[2]); |
| byte[] classBytes = CodegenUtils.getClassData(parts[2]); |
| String opcode = parts[0] + CodegenUtils.createInstance(cls).getSpoofType(); |
| |
| for( int i=3; i<parts.length-2; i++ ) |
| inlist.add(new CPOperand(parts[i])); |
| CPOperand out = new CPOperand(parts[parts.length-2]); |
| //note: number of threads parts[parts.length-1] always ignored |
| |
| return new SpoofSPInstruction(cls, classBytes, inlist.toArray(new CPOperand[0]), out, opcode, str); |
| } |
| |
	/**
	 * Executes the generated fused operator on spark: decides broadcast vs rdd
	 * side inputs, joins the main input rdd with non-broadcast side inputs, and
	 * dispatches on the operator template type (cellwise, multi-aggregate,
	 * outer-product, rowwise) to produce either an output rdd or an aggregated
	 * scalar/matrix result at the driver.
	 *
	 * @param ec execution context (must be a SparkExecutionContext)
	 */
	@Override
	public void processInstruction(ExecutionContext ec) {
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		
		//decide upon broadcast side inputs
		boolean[] bcVect = determineBroadcastInputs(sec, _in);
		boolean[] bcVect2 = getMatrixBroadcastVector(sec, _in, bcVect); //compacted to matrix inputs only
		int main = getMainInputIndex(_in, bcVect);
		
		//create joined input rdd w/ replication if needed
		DataCharacteristics mcIn = sec.getDataCharacteristics(_in[main].getName());
		JavaPairRDD<MatrixIndexes, MatrixBlock[]> in = createJoinedInputRDD(
			sec, _in, bcVect, (_class.getSuperclass() == SpoofOuterProduct.class));
		JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
		
		//create lists of input broadcasts and scalars
		ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<>();
		ArrayList<ScalarObject> scalars = new ArrayList<>();
		for( int i=0; i<_in.length; i++ ) {
			if( _in[i].getDataType()==DataType.MATRIX && bcVect[i] ) {
				bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
			}
			else if(_in[i].getDataType()==DataType.SCALAR) {
				//note: even if literal, it might be compiled as scalar placeholder
				scalars.add(sec.getScalarInput(_in[i]));
			}
		}
		
		//execute generated operator
		if(_class.getSuperclass() == SpoofCellwise.class) //CELL
		{
			SpoofCellwise op = (SpoofCellwise) CodegenUtils.createInstance(_class);
			AggregateOperator aggop = getAggregateOperator(op.getAggOp());
			
			if( _out.getDataType()==DataType.MATRIX ) {
				//execute codegen block operation
				out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(),
					_classBytes, bcVect2, bcMatrices, scalars, mcIn.getBlocksize()), true);
				
				//row/column aggregates need a second aggregation pass whenever the
				//aggregated dimension spans more than one block
				if( (op.getCellType()==CellType.ROW_AGG && mcIn.getCols() > mcIn.getBlocksize())
					|| (op.getCellType()==CellType.COL_AGG && mcIn.getRows() > mcIn.getBlocksize())) {
					long numBlocks = (op.getCellType()==CellType.ROW_AGG ) ?
						mcIn.getNumRowBlocks() : mcIn.getNumColBlocks();
					out = RDDAggregateUtils.aggByKeyStable(out, aggop,
						(int)Math.min(out.getNumPartitions(), numBlocks), false);
				}
				sec.setRDDHandleForVariable(_out.getName(), out);
				
				//maintain lineage info and output characteristics
				maintainLineageInfo(sec, _in, bcVect, _out);
				updateOutputDataCharacteristics(sec, op);
			}
			else { //SCALAR
				//full aggregation into a single value at the driver
				out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(),
					_classBytes, bcVect2, bcMatrices, scalars, mcIn.getBlocksize()), true);
				MatrixBlock tmpMB = RDDAggregateUtils.aggStable(out, aggop);
				sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0)));
			}
		}
		else if(_class.getSuperclass() == SpoofMultiAggregate.class) //MAGG
		{
			SpoofMultiAggregate op = (SpoofMultiAggregate) CodegenUtils.createInstance(_class);
			AggOp[] aggOps = op.getAggOps();
			//fold per-block partial aggregates into a single output block at the driver
			MatrixBlock tmpMB = in.mapToPair(new MultiAggregateFunction(_class.getName(),
				_classBytes, bcVect2, bcMatrices, scalars, mcIn.getBlocksize()))
				.values().fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps) );
			sec.setMatrixOutput(_out.getName(), tmpMB);
		}
		else if(_class.getSuperclass() == SpoofOuterProduct.class) //OUTER
		{
			if( _out.getDataType()==DataType.MATRIX ) {
				SpoofOperator op = CodegenUtils.createInstance(_class);
				OutProdType type = ((SpoofOuterProduct)op).getOuterProdType();
				
				//update matrix characteristics
				updateOutputDataCharacteristics(sec, op);
				DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
				
				out = in.mapPartitionsToPair(new OuterProductFunction(
					_class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true);
				//left/right outer products emit partial blocks per key that need summation
				if(type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT ) {
					long numBlocks = mcOut.getNumRowBlocks() * mcOut.getNumColBlocks();
					out = RDDAggregateUtils.sumByKeyStable(out,
						(int)Math.min(out.getNumPartitions(), numBlocks), false);
				}
				sec.setRDDHandleForVariable(_out.getName(), out);
				
				//maintain lineage info and output characteristics
				maintainLineageInfo(sec, _in, bcVect, _out);
			}
			else {
				//scalar output: sum all partial results at the driver
				out = in.mapPartitionsToPair(new OuterProductFunction(
					_class.getName(), _classBytes, bcVect2, bcMatrices, scalars), true);
				MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
				sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0)));
			}
		}
		else if( _class.getSuperclass() == SpoofRowwise.class ) { //ROW
			//row templates require single-column-block inputs
			if( mcIn.getCols() > mcIn.getBlocksize() ) {
				throw new DMLRuntimeException("Invalid spark rowwise operator w/ ncol=" +
					mcIn.getCols()+", ncolpb="+mcIn.getBlocksize()+".");
			}
			SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class);
			//number of columns of the side input for B1 row types, constant dim2 if known, -1 otherwise
			long clen2 = op.getRowType().isConstDim2(op.getConstDim2()) ? op.getConstDim2() :
				op.getRowType().isRowTypeB1() ? sec.getDataCharacteristics(_in[1].getName()).getCols() : -1;
			RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), _classBytes, bcVect2,
				bcMatrices, scalars, mcIn.getBlocksize(), (int)mcIn.getCols(), (int)clen2);
			out = in.mapPartitionsToPair(fmmc, op.getRowType()==RowType.ROW_AGG
				|| op.getRowType() == RowType.NO_AGG);
			
			if( op.getRowType().isColumnAgg() || op.getRowType()==RowType.FULL_AGG ) {
				//aggregate partial results at the driver
				MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out);
				if( op.getRowType().isColumnAgg() )
					sec.setMatrixOutput(_out.getName(), tmpMB);
				else
					sec.setScalarOutput(_out.getName(),
						new DoubleObject(tmpMB.quickGetValue(0, 0)));
			}
			else //row-agg or no-agg
			{
				if( op.getRowType()==RowType.ROW_AGG && mcIn.getCols() > mcIn.getBlocksize() ) {
					out = RDDAggregateUtils.sumByKeyStable(out,
						(int)Math.min(out.getNumPartitions(), mcIn.getNumRowBlocks()), false);
				}
				sec.setRDDHandleForVariable(_out.getName(), out);
				
				//maintain lineage info and output characteristics
				maintainLineageInfo(sec, _in, bcVect, _out);
				updateOutputDataCharacteristics(sec, op);
			}
		}
		else {
			throw new DMLRuntimeException("Operator " + _class.getSuperclass() + " is not supported on Spark");
		}
	}
| |
| private static boolean[] determineBroadcastInputs(SparkExecutionContext sec, CPOperand[] inputs) { |
| boolean[] ret = new boolean[inputs.length]; |
| double localBudget = OptimizerUtils.getLocalMemBudget() |
| - CacheableData.getBroadcastSize(); //account for other broadcasts |
| double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget(); |
| |
| //decided for each matrix input if it fits into remaining memory |
| //budget; the major input, i.e., inputs[0] is always an RDD |
| for( int i=0; i<inputs.length; i++ ) |
| if( inputs[i].getDataType().isMatrix() ) { |
| DataCharacteristics mc = sec.getDataCharacteristics(inputs[i].getName()); |
| double sizeL = OptimizerUtils.estimateSizeExactSparsity(mc); |
| double sizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc); |
| //account for partitioning and local/remote budgets |
| ret[i] = localBudget > (sizeL + sizeP) && bcBudget > sizeP; |
| localBudget -= ret[i] ? sizeP : 0; //in local block manager |
| bcBudget -= ret[i] ? sizeP : 0; //in remote block managers |
| } |
| |
| //ensure there is at least one RDD input, with awareness for scalars |
| if( !IntStream.range(0, ret.length).anyMatch(i -> inputs[i].isMatrix() && !ret[i]) ) |
| ret[0] = false; |
| |
| return ret; |
| } |
| |
| private static boolean[] getMatrixBroadcastVector(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect) { |
| int numMtx = (int) Arrays.stream(inputs) |
| .filter(in -> in.getDataType().isMatrix()).count(); |
| boolean[] ret = new boolean[numMtx]; |
| for(int i=0, pos=0; i<inputs.length; i++) |
| if( inputs[i].getDataType().isMatrix() ) |
| ret[pos++] = bcVect[i]; |
| return ret; |
| } |
| |
	/**
	 * Creates the joined input rdd of {@code MatrixIndexes -> MatrixBlock[]}:
	 * starts from the main input rdd and successively joins in all non-broadcast
	 * matrix side inputs, replicating their blocks where their blocking does not
	 * match the main input (or, for outer-product templates, replicating the
	 * right factor V over the row blocks of the main input).
	 *
	 * @param sec    spark execution context
	 * @param inputs all input operands
	 * @param bcVect per-operand broadcast flags (true inputs are NOT joined)
	 * @param outer  true for outer-product templates (special handling of input 2)
	 * @return joined rdd with block arrays in operand order
	 */
	private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> createJoinedInputRDD(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, boolean outer) {
		//get input rdd for main input
		int main = getMainInputIndex(inputs, bcVect);
		DataCharacteristics mcIn = sec.getDataCharacteristics(inputs[main].getName());
		JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(inputs[main].getName());
		//wrap single blocks into arrays to obtain a uniform join signature
		JavaPairRDD<MatrixIndexes, MatrixBlock[]> ret = in.mapValues(new MapInputSignature());
		
		for( int i=0; i<inputs.length; i++ )
			if( i != main && inputs[i].getDataType().isMatrix() && !bcVect[i] ) {
				//create side input rdd
				String varname = inputs[i].getName();
				JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec
					.getBinaryMatrixBlockRDDHandleForVariable(varname);
				DataCharacteristics mcTmp = sec.getDataCharacteristics(varname);
				//replicate blocks if mismatch with main input
				if( outer && i==2 )
					tmp = tmp.flatMapToPair(new ReplicateRightFactorFunction(mcIn.getRows(), mcIn.getBlocksize()));
				else if( mcIn.getNumRowBlocks() > mcTmp.getNumRowBlocks() )
					tmp = tmp.flatMapToPair(new ReplicateBlockFunction(mcIn.getRows(), mcIn.getBlocksize(), false));
				else if( mcIn.getNumColBlocks() > mcTmp.getNumColBlocks() )
					tmp = tmp.flatMapToPair(new ReplicateBlockFunction(mcIn.getCols(), mcIn.getBlocksize(), true));
				//join main and side inputs and consolidate signature
				ret = ret.join(tmp)
					.mapValues(new MapJoinSignature());
			}
		
		return ret;
	}
| |
| private static void maintainLineageInfo(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, CPOperand output) { |
| //add lineage info for all rdd/broadcast inputs |
| for( int i=0; i<inputs.length; i++ ) |
| if( inputs[i].getDataType().isMatrix() ) |
| sec.addLineage(output.getName(), inputs[i].getName(), bcVect[i]); |
| } |
| |
| private static int getMainInputIndex(CPOperand[] inputs, boolean[] bcVect) { |
| return IntStream.range(0, bcVect.length) |
| .filter(i -> inputs[i].isMatrix() && !bcVect[i]).min().orElse(0); |
| } |
| |
	/**
	 * Updates the data characteristics (dims, blocksize) of the output variable
	 * according to the operator template and its aggregation/output type, derived
	 * from the characteristics of the respective inputs.
	 *
	 * @param sec spark execution context
	 * @param op  instantiated generated operator
	 */
	private void updateOutputDataCharacteristics(SparkExecutionContext sec, SpoofOperator op) {
		if(op instanceof SpoofCellwise) {
			DataCharacteristics mcIn = sec.getDataCharacteristics(_in[0].getName());
			DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
			if( ((SpoofCellwise)op).getCellType()==CellType.ROW_AGG )
				mcOut.set(mcIn.getRows(), 1, mcIn.getBlocksize(), mcIn.getBlocksize());
			else if( ((SpoofCellwise)op).getCellType()==CellType.NO_AGG )
				mcOut.set(mcIn.getRows(), mcIn.getCols(), mcIn.getBlocksize(), mcIn.getBlocksize());
		}
		else if(op instanceof SpoofOuterProduct) {
			//inputs: X (data), U (left factor), V (right factor)
			DataCharacteristics mcIn1 = sec.getDataCharacteristics(_in[0].getName()); //X
			DataCharacteristics mcIn2 = sec.getDataCharacteristics(_in[1].getName()); //U
			DataCharacteristics mcIn3 = sec.getDataCharacteristics(_in[2].getName()); //V
			DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
			OutProdType type = ((SpoofOuterProduct)op).getOuterProdType();
			
			if( type == OutProdType.CELLWISE_OUTER_PRODUCT)
				mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getBlocksize(), mcIn1.getBlocksize());
			else if( type == OutProdType.LEFT_OUTER_PRODUCT)
				mcOut.set(mcIn3.getRows(), mcIn3.getCols(), mcIn3.getBlocksize(), mcIn3.getBlocksize());
			else if( type == OutProdType.RIGHT_OUTER_PRODUCT )
				mcOut.set(mcIn2.getRows(), mcIn2.getCols(), mcIn2.getBlocksize(), mcIn2.getBlocksize());
		}
		else if(op instanceof SpoofRowwise) {
			DataCharacteristics mcIn = sec.getDataCharacteristics(_in[0].getName());
			DataCharacteristics mcOut = sec.getDataCharacteristics(_out.getName());
			RowType type = ((SpoofRowwise)op).getRowType();
			if( type == RowType.NO_AGG )
				mcOut.set(mcIn);
			else if( type == RowType.ROW_AGG )
				mcOut.set(mcIn.getRows(), 1,
					mcIn.getBlocksize(), mcIn.getBlocksize());
			else if( type == RowType.COL_AGG )
				mcOut.set(1, mcIn.getCols(), mcIn.getBlocksize(), mcIn.getBlocksize());
			else if( type == RowType.COL_AGG_T )
				mcOut.set(mcIn.getCols(), 1, mcIn.getBlocksize(), mcIn.getBlocksize());
		}
	}
| |
	/**
	 * Base class of all spoof spark functions; carries the generated class name
	 * and bytecode (for lazy, executor-side class loading), the broadcast side
	 * inputs, and the scalar inputs, and assembles per-block matrix input lists
	 * in original operand order.
	 */
	private static class SpoofFunction implements Serializable
	{
		private static final long serialVersionUID = 2953479427746463003L;
		
		protected final boolean[] _bcInd; //broadcast flags over matrix inputs only
		protected final ArrayList<PartitionedBroadcast<MatrixBlock>> _inputs;
		protected final ArrayList<ScalarObject> _scalars;
		protected final byte[] _classBytes; //bytecode for executor-side class loading
		protected final String _className;
		
		protected SpoofFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) {
			_bcInd = bcInd;
			_inputs = bcMatrices;
			_scalars = scalars;
			_classBytes = classBytes;
			_className = className;
		}
		
		/** Collects all matrix inputs for the given block index (w/o outer-product handling). */
		protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn) {
			return getAllMatrixInputs(ixIn, blkIn, false);
		}
		
		/**
		 * Collects all matrix inputs (main and side inputs) for the given block
		 * index in original operand order, taking each block either from the
		 * joined rdd value or from the corresponding partitioned broadcast.
		 *
		 * @param ixIn  index of the current main-input block
		 * @param blkIn joined rdd blocks (main input and non-broadcast side inputs)
		 * @param outer true for outer-product templates, where input 2 (V) is
		 *              addressed by the column index of the main input
		 * @return list of matrix blocks in operand order
		 */
		protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn, boolean outer) {
			ArrayList<MatrixBlock> ret = new ArrayList<>();
			//add all rdd/broadcast inputs (main and side inputs)
			for( int i=0, posRdd=0, posBc=0; i<_bcInd.length; i++ ) {
				if( _bcInd[i] ) {
					//broadcast input: pick the aligned block; index 1 serves as
					//fallback for side inputs with a single row/column block
					PartitionedBroadcast<MatrixBlock> pb = _inputs.get(posBc++);
					int rowIndex = (int) ((outer && i==2) ? ixIn.getColumnIndex() :
						(pb.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1);
					int colIndex = (int) ((outer && i==2) ? 1 :
						(pb.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1);
					ret.add(pb.getBlock(rowIndex, colIndex));
				}
				else
					ret.add(blkIn[posRdd++]);
			}
			return ret;
		}
	}
| |
| private static class RowwiseFunction extends SpoofFunction |
| implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = -7926980450209760212L; |
| |
| private final int _blen; |
| private final int _clen; |
| private final int _clen2; |
| private SpoofRowwise _op = null; |
| |
| public RowwiseFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, |
| ArrayList<ScalarObject> scalars, int blen, int clen, int clen2) { |
| super(className, classBytes, bcInd, bcMatrices, scalars); |
| _blen = blen; |
| _clen = clen; |
| _clen2 = clen; |
| } |
| |
| @Override |
| public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg ) { |
| //lazy load of shipped class |
| if( _op == null ) { |
| Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes); |
| _op = (SpoofRowwise) CodegenUtils.createInstance(loadedClass); |
| } |
| |
| //setup local memory for reuse |
| LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, _clen2); |
| |
| ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<>(); |
| boolean aggIncr = (_op.getRowType().isColumnAgg() //aggregate entire partition |
| || _op.getRowType() == RowType.FULL_AGG); |
| MatrixBlock blkOut = aggIncr ? new MatrixBlock() : null; |
| |
| while( arg.hasNext() ) { |
| //get main input block and indexes |
| Tuple2<MatrixIndexes,MatrixBlock[]> e = arg.next(); |
| MatrixIndexes ixIn = e._1(); |
| MatrixBlock[] blkIn = e._2(); |
| long rix = (ixIn.getRowIndex()-1) * _blen; //0-based |
| |
| //prepare output and execute single-threaded operator |
| ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn); |
| blkOut = aggIncr ? blkOut : new MatrixBlock(); |
| blkOut = _op.execute(inputs, _scalars, blkOut, false, aggIncr, rix); |
| if( !aggIncr ) { |
| MatrixIndexes ixOut = new MatrixIndexes(ixIn.getRowIndex(), |
| _op.getRowType()!=RowType.NO_AGG ? 1 : ixIn.getColumnIndex()); |
| ret.add(new Tuple2<>(ixOut, blkOut)); |
| } |
| } |
| |
| //cleanup and final result preparations |
| LibSpoofPrimitives.cleanupThreadLocalMemory(); |
| if( aggIncr ) { |
| blkOut.recomputeNonZeros(); |
| blkOut.examSparsity(); //deferred format change |
| ret.add(new Tuple2<>(new MatrixIndexes(1,1), blkOut)); |
| } |
| |
| return ret.iterator(); |
| } |
| } |
| |
	/**
	 * Executes a generated cellwise operator over all blocks of a partition,
	 * emitting per-block results (full aggregation still produces one 1x1
	 * block per input block; the final aggregation happens at the caller).
	 */
	private static class CellwiseFunction extends SpoofFunction
		implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
	{
		private static final long serialVersionUID = -8209188316939435099L;
		
		private SpoofCellwise _op = null; //lazily loaded generated operator
		private final int _blen; //blocksize of the main input
		
		public CellwiseFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars, int blen) {
			super(className, classBytes, bcInd, bcMatrices, scalars);
			_blen = blen;
		}
		
		@Override
		public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
			throws Exception
		{
			//lazy load of shipped class
			if( _op == null ) {
				Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes);
				_op = (SpoofCellwise) CodegenUtils.createInstance(loadedClass);
			}
			
			List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
			while(arg.hasNext())
			{
				Tuple2<MatrixIndexes,MatrixBlock[]> tmp = arg.next();
				MatrixIndexes ixIn = tmp._1();
				MatrixBlock[] blkIn = tmp._2();
				MatrixIndexes ixOut = ixIn;
				MatrixBlock blkOut = new MatrixBlock();
				ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn);
				long rix = (ixIn.getRowIndex()-1) * _blen; //0-based global row offset
				
				//execute core operation
				if( _op.getCellType()==CellType.FULL_AGG ) {
					//full aggregation: scalar result wrapped into a 1x1 block
					ScalarObject obj = _op.execute(inputs, _scalars, 1, rix);
					blkOut.reset(1, 1);
					blkOut.quickSetValue(0, 0, obj.getDoubleValue());
				}
				else {
					//row/column aggregates collapse the respective output index to 1
					if( _op.getCellType()==CellType.ROW_AGG )
						ixOut = new MatrixIndexes(ixOut.getRowIndex(), 1);
					else if(_op.getCellType()==CellType.COL_AGG)
						ixOut = new MatrixIndexes(1, ixOut.getColumnIndex());
					blkOut = _op.execute(inputs, _scalars, blkOut, 1, rix);
				}
				ret.add(new Tuple2<>(ixOut, blkOut));
			}
			return ret.iterator();
		}
	}
| |
| private static class MultiAggregateFunction extends SpoofFunction |
| implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = -5224519291577332734L; |
| |
| private SpoofMultiAggregate _op = null; |
| private final int _blen; |
| |
| public MultiAggregateFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars, int blen) { |
| super(className, classBytes, bcInd, bcMatrices, scalars); |
| _blen = blen; |
| } |
| |
| @Override |
| public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg) |
| throws Exception |
| { |
| //lazy load of shipped class |
| if( _op == null ) { |
| Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes); |
| _op = (SpoofMultiAggregate) CodegenUtils.createInstance(loadedClass); |
| } |
| |
| //execute core operation |
| ArrayList<MatrixBlock> inputs = getAllMatrixInputs(arg._1(), arg._2()); |
| MatrixBlock blkOut = new MatrixBlock(); |
| long rix = (arg._1().getRowIndex()-1) * _blen; //0-based |
| blkOut = _op.execute(inputs, _scalars, blkOut, 1, rix); |
| |
| return new Tuple2<>(arg._1(), blkOut); |
| } |
| } |
| |
| private static class MultiAggAggregateFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> |
| { |
| private static final long serialVersionUID = 5978731867787952513L; |
| |
| private AggOp[] _ops = null; |
| |
| public MultiAggAggregateFunction( AggOp[] ops ) { |
| _ops = ops; |
| } |
| |
| @Override |
| public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) |
| throws Exception |
| { |
| //prepare combiner block |
| if( arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0) { |
| arg0.copy(arg1); |
| return arg0; |
| } |
| else if( arg1.getNumRows() <= 0 || arg1.getNumColumns() <= 0 ) { |
| return arg0; |
| } |
| |
| //aggregate second input (in-place) |
| SpoofMultiAggregate.aggregatePartialResults(_ops, arg0, arg1); |
| |
| return arg0; |
| } |
| } |
| |
	/**
	 * Executes a generated outer-product operator over all blocks of a
	 * partition; aggregated outer products produce 1x1 blocks, left/right
	 * outer products emit partial blocks re-keyed for subsequent summation.
	 */
	private static class OuterProductFunction extends SpoofFunction
		implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
	{
		private static final long serialVersionUID = -8209188316939435099L;
		
		private SpoofOperator _op = null; //lazily loaded generated operator
		
		public OuterProductFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) {
			super(className, classBytes, bcInd, bcMatrices, scalars);
		}
		
		@Override
		public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
			throws Exception
		{
			//lazy load of shipped class
			if( _op == null ) {
				Class<?> loadedClass = CodegenUtils.getClassSync(_className, _classBytes);
				_op = CodegenUtils.createInstance(loadedClass);
			}
			
			List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
			while(arg.hasNext())
			{
				Tuple2<MatrixIndexes,MatrixBlock[]> tmp = arg.next();
				MatrixIndexes ixIn = tmp._1();
				MatrixBlock[] blkIn = tmp._2();
				MatrixBlock blkOut = new MatrixBlock();
				
				//assemble inputs w/ outer-product indexing of the right factor
				ArrayList<MatrixBlock> inputs = getAllMatrixInputs(ixIn, blkIn, true);
				//execute core operation
				if(((SpoofOuterProduct)_op).getOuterProdType()==OutProdType.AGG_OUTER_PRODUCT) {
					//aggregated outer product: scalar result wrapped into a 1x1 block
					ScalarObject obj = _op.execute(inputs, _scalars,1);
					blkOut.reset(1, 1);
					blkOut.quickSetValue(0, 0, obj.getDoubleValue());
				}
				else {
					blkOut = _op.execute(inputs, _scalars, blkOut);
				}
				
				ret.add(new Tuple2<>(createOutputIndexes(ixIn,_op), blkOut));
			}
			
			return ret.iterator();
		}
		
		/** Maps the input block index to the output key according to the outer-product type. */
		private static MatrixIndexes createOutputIndexes(MatrixIndexes in, SpoofOperator spoofOp) {
			if( ((SpoofOuterProduct)spoofOp).getOuterProdType() == OutProdType.LEFT_OUTER_PRODUCT )
				return new MatrixIndexes(in.getColumnIndex(), 1);
			else if ( ((SpoofOuterProduct)spoofOp).getOuterProdType() == OutProdType.RIGHT_OUTER_PRODUCT)
				return new MatrixIndexes(in.getRowIndex(), 1);
			else
				return in;
		}
	}
| |
| public static class ReplicateRightFactorFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = -7295989688796126442L; |
| |
| private final long _len; |
| private final long _blen; |
| |
| public ReplicateRightFactorFunction(long len, long blen) { |
| _len = len; |
| _blen = blen; |
| } |
| |
| @Override |
| public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) |
| throws Exception |
| { |
| LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new LinkedList<>(); |
| MatrixIndexes ixIn = arg0._1(); |
| MatrixBlock blkIn = arg0._2(); |
| |
| long numBlocks = (long) Math.ceil((double)_len/_blen); |
| |
| //replicate wrt # row blocks in LHS |
| long j = ixIn.getRowIndex(); |
| for( long i=1; i<=numBlocks; i++ ) { |
| MatrixIndexes tmpix = new MatrixIndexes(i, j); |
| MatrixBlock tmpblk = blkIn; |
| ret.add( new Tuple2<>(tmpix, tmpblk) ); |
| } |
| |
| //output list of new tuples |
| return ret.iterator(); |
| } |
| } |
| |
| public static AggregateOperator getAggregateOperator(AggOp aggop) { |
| if( aggop == AggOp.SUM || aggop == AggOp.SUM_SQ ) |
| return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.NONE); |
| else if( aggop == AggOp.MIN ) |
| return new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject(BuiltinCode.MIN), CorrectionLocationType.NONE); |
| else if( aggop == AggOp.MAX ) |
| return new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject(BuiltinCode.MAX), CorrectionLocationType.NONE); |
| return null; |
| } |
| } |