| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.instructions.spark; |
| |
| import org.apache.spark.api.java.JavaPairRDD; |
| import org.apache.spark.api.java.function.PairFlatMapFunction; |
| import org.apache.spark.api.java.function.PairFunction; |
| import org.apache.sysds.lops.WeightedCrossEntropy; |
| import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType; |
| import org.apache.sysds.lops.WeightedDivMM; |
| import org.apache.sysds.lops.WeightedDivMM.WDivMMType; |
| import org.apache.sysds.lops.WeightedDivMMR; |
| import org.apache.sysds.lops.WeightedSigmoid; |
| import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType; |
| import org.apache.sysds.lops.WeightedSquaredLoss; |
| import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType; |
| import org.apache.sysds.lops.WeightedSquaredLossR; |
| import org.apache.sysds.lops.WeightedUnaryMM; |
| import org.apache.sysds.lops.WeightedUnaryMM.WUMMType; |
| import org.apache.sysds.lops.WeightedUnaryMMR; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.cp.DoubleObject; |
| import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator; |
| import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast; |
| import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction; |
| import org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction; |
| import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.data.MatrixIndexes; |
| import org.apache.sysds.runtime.matrix.operators.Operator; |
| import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import scala.Tuple2; |
| |
| import java.io.Serializable; |
| import java.util.ArrayList; |
| import java.util.Iterator; |
| |
| public class QuaternarySPInstruction extends ComputationSPInstruction { |
| private CPOperand _input4 = null; |
| private boolean _cacheU = false; |
| private boolean _cacheV = false; |
| |
| private QuaternarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, |
| CPOperand out, boolean cacheU, boolean cacheV, String opcode, String str) { |
| super(SPType.Quaternary, op, in1, in2, in3, out, opcode, str); |
| _input4 = in4; |
| _cacheU = cacheU; |
| _cacheV = cacheV; |
| } |
| |
| public static QuaternarySPInstruction parseInstruction( String str ) { |
| String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
| String opcode = parts[0]; |
| |
| //validity check |
| if( !InstructionUtils.isDistQuaternaryOpcode(opcode) ) { |
| throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode); |
| } |
| |
| //instruction parsing |
| if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) //wsloss |
| || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) ) |
| { |
| boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode); |
| |
| //check number of fields (4 inputs, output, type) |
| if( isRed ) |
| InstructionUtils.checkNumFields ( parts, 8 ); |
| else |
| InstructionUtils.checkNumFields ( parts, 6 ); |
| |
| CPOperand in1 = new CPOperand(parts[1]); |
| CPOperand in2 = new CPOperand(parts[2]); |
| CPOperand in3 = new CPOperand(parts[3]); |
| CPOperand in4 = new CPOperand(parts[4]); |
| CPOperand out = new CPOperand(parts[5]); |
| WeightsType wtype = WeightsType.valueOf(parts[6]); |
| |
| //in mappers always through distcache, in reducers through distcache/shuffle |
| boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true; |
| boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true; |
| |
| return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str); |
| } |
| else if( WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //wumm |
| || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode) ) |
| { |
| boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode); |
| |
| //check number of fields (4 inputs, output, type) |
| if( isRed ) |
| InstructionUtils.checkNumFields ( parts, 8 ); |
| else |
| InstructionUtils.checkNumFields ( parts, 6 ); |
| |
| String uopcode = parts[1]; |
| CPOperand in1 = new CPOperand(parts[2]); |
| CPOperand in2 = new CPOperand(parts[3]); |
| CPOperand in3 = new CPOperand(parts[4]); |
| CPOperand out = new CPOperand(parts[5]); |
| WUMMType wtype = WUMMType.valueOf(parts[6]); |
| |
| //in mappers always through distcache, in reducers through distcache/shuffle |
| boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true; |
| boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true; |
| |
| return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str); |
| } |
| else if( WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) //wdivmm |
| || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) ) |
| { |
| boolean isRed = opcode.startsWith("red"); |
| |
| //check number of fields (4 inputs, output, type) |
| if( isRed ) |
| InstructionUtils.checkNumFields( parts, 8 ); |
| else |
| InstructionUtils.checkNumFields( parts, 6 ); |
| |
| CPOperand in1 = new CPOperand(parts[1]); |
| CPOperand in2 = new CPOperand(parts[2]); |
| CPOperand in3 = new CPOperand(parts[3]); |
| CPOperand in4 = new CPOperand(parts[4]); |
| CPOperand out = new CPOperand(parts[5]); |
| |
| //in mappers always through distcache, in reducers through distcache/shuffle |
| boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true; |
| boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true; |
| |
| final WDivMMType wt = WDivMMType.valueOf(parts[6]); |
| QuaternaryOperator qop = (wt.hasScalar() ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName())) : new QuaternaryOperator(wt)); |
| return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str); |
| } |
| else //map/redwsigmoid, map/redwcemm |
| { |
| boolean isRed = opcode.startsWith("red"); |
| int addInput4 = (opcode.endsWith("wcemm")) ? 1 : 0; |
| |
| //check number of fields (3 or 4 inputs, output, type) |
| if( isRed ) |
| InstructionUtils.checkNumFields( parts, 7 + addInput4 ); |
| else |
| InstructionUtils.checkNumFields( parts, 5 + addInput4 ); |
| |
| CPOperand in1 = new CPOperand(parts[1]); |
| CPOperand in2 = new CPOperand(parts[2]); |
| CPOperand in3 = new CPOperand(parts[3]); |
| CPOperand out = new CPOperand(parts[4 + addInput4]); |
| |
| //in mappers always through distcache, in reducers through distcache/shuffle |
| boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true; |
| boolean cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true; |
| |
| if( opcode.endsWith("wsigmoid") ) |
| return new QuaternarySPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str); |
| else if( opcode.endsWith("wcemm") ) { |
| CPOperand in4 = new CPOperand(parts[4]); |
| final WCeMMType wt = WCeMMType.valueOf(parts[6]); |
| QuaternaryOperator qop = (wt.hasFourInputs() ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName())) : new QuaternaryOperator(wt)); |
| return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str); |
| } |
| } |
| |
| return null; |
| } |
| |
| @Override |
| public void processInstruction(ExecutionContext ec) { |
| SparkExecutionContext sec = (SparkExecutionContext)ec; |
| QuaternaryOperator qop = (QuaternaryOperator) _optr; |
| |
| //tracking of rdds and broadcasts (for lineage maintenance) |
| ArrayList<String> rddVars = new ArrayList<>(); |
| ArrayList<String> bcVars = new ArrayList<>(); |
| |
| JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() ); |
| JavaPairRDD<MatrixIndexes, MatrixBlock> out = null; |
| |
| DataCharacteristics inMc = sec.getDataCharacteristics( input1.getName() ); |
| long rlen = inMc.getRows(); |
| long clen = inMc.getCols(); |
| int blen = inMc.getBlocksize(); |
| |
| //pre-filter empty blocks (ultra-sparse matrices) for full aggregates |
| //(map/redwsloss, map/redwcemm); safe because theses ops produce a scalar |
| if( qop.wtype1 != null || qop.wtype4 != null ) { |
| in = in.filter(new FilterNonEmptyBlocksFunction()); |
| } |
| |
| //map-side only operation (one rdd input, two broadcasts) |
| if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(getOpcode()) |
| || WeightedSigmoid.OPCODE.equalsIgnoreCase(getOpcode()) |
| || WeightedDivMM.OPCODE.equalsIgnoreCase(getOpcode()) |
| || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(getOpcode()) |
| || WeightedUnaryMM.OPCODE.equalsIgnoreCase(getOpcode())) |
| { |
| PartitionedBroadcast<MatrixBlock> bc1 = sec.getBroadcastForVariable( input2.getName() ); |
| PartitionedBroadcast<MatrixBlock> bc2 = sec.getBroadcastForVariable( input3.getName() ); |
| |
| //partitioning-preserving mappartitions (key access required for broadcast loopkup) |
| boolean noKeyChange = (qop.wtype3 == null || qop.wtype3.isBasic()); //only wdivmm changes keys |
| out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), noKeyChange); |
| |
| rddVars.add( input1.getName() ); |
| bcVars.add( input2.getName() ); |
| bcVars.add( input3.getName() ); |
| } |
| //reduce-side operation (two/three/four rdd inputs, zero/one/two broadcasts) |
| else |
| { |
| PartitionedBroadcast<MatrixBlock> bc1 = _cacheU ? sec.getBroadcastForVariable( input2.getName() ) : null; |
| PartitionedBroadcast<MatrixBlock> bc2 = _cacheV ? sec.getBroadcastForVariable( input3.getName() ) : null; |
| JavaPairRDD<MatrixIndexes,MatrixBlock> inU = (!_cacheU) ? sec.getBinaryMatrixBlockRDDHandleForVariable( input2.getName() ) : null; |
| JavaPairRDD<MatrixIndexes,MatrixBlock> inV = (!_cacheV) ? sec.getBinaryMatrixBlockRDDHandleForVariable( input3.getName() ) : null; |
| JavaPairRDD<MatrixIndexes,MatrixBlock> inW = (qop.hasFourInputs() && !_input4.isLiteral()) ? |
| sec.getBinaryMatrixBlockRDDHandleForVariable( _input4.getName() ) : null; |
| |
| //preparation of transposed and replicated U |
| if( inU != null ) |
| inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, blen, true)); |
| |
| //preparation of transposed and replicated V |
| if( inV != null ) |
| inV = inV.mapToPair(new TransposeFactorIndexesFunction()) |
| .flatMapToPair(new ReplicateBlockFunction(rlen, blen, false)); |
| |
| //functions calls w/ two rdd inputs |
| if( inU != null && inV == null && inW == null ) |
| out = in.join(inU) |
| .mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2)); |
| else if( inU == null && inV != null && inW == null ) |
| out = in.join(inV) |
| .mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2)); |
| else if( inU == null && inV == null && inW != null ) |
| out = in.join(inW) |
| .mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2)); |
| //function calls w/ three rdd inputs |
| else if( inU != null && inV != null && inW == null ) |
| out = in.join(inU).join(inV) |
| .mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2)); |
| else if( inU != null && inV == null && inW != null ) |
| out = in.join(inU).join(inW) |
| .mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2)); |
| else if( inU == null && inV != null && inW != null ) |
| out = in.join(inV).join(inW) |
| .mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2)); |
| else if( inU == null && inV == null && inW == null ) { |
| out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), false); |
| } |
| //function call w/ four rdd inputs |
| else //need keys in case of wdivmm |
| out = in.join(inU).join(inV).join(inW) |
| .mapToPair(new RDDQuaternaryFunction4(qop)); |
| |
| //keep variable names for lineage maintenance |
| if( inU == null ) bcVars.add(input2.getName()); else rddVars.add(input2.getName()); |
| if( inV == null ) bcVars.add(input3.getName()); else rddVars.add(input3.getName()); |
| if( inW != null ) rddVars.add(_input4.getName()); |
| } |
| |
| //output handling, incl aggregation |
| if( qop.wtype1 != null || qop.wtype4 != null ) //map/redwsloss, map/redwcemm |
| { |
| //full aggregate and cast to scalar |
| MatrixBlock tmp = RDDAggregateUtils.sumStable(out); |
| DoubleObject ret = new DoubleObject(tmp.getValue(0, 0)); |
| sec.setVariable(output.getName(), ret); |
| } |
| else //map/redwsigmoid, map/redwdivmm, map/redwumm |
| { |
| //aggregation if required (map/redwdivmm) |
| if( qop.wtype3 != null && !qop.wtype3.isBasic() ) |
| out = RDDAggregateUtils.sumByKeyStable(out, false); |
| |
| //put output RDD handle into symbol table |
| sec.setRDDHandleForVariable(output.getName(), out); |
| //maintain lineage information for output rdd |
| for( String rddVar : rddVars ) |
| sec.addLineageRDD(output.getName(), rddVar); |
| for( String bcVar : bcVars ) |
| sec.addLineageBroadcast(output.getName(), bcVar); |
| |
| //update matrix characteristics |
| updateOutputDataCharacteristics(sec, qop); |
| } |
| } |
| |
| private void updateOutputDataCharacteristics(SparkExecutionContext sec, QuaternaryOperator qop) { |
| DataCharacteristics mcIn1 = sec.getDataCharacteristics(input1.getName()); |
| DataCharacteristics mcIn2 = sec.getDataCharacteristics(input2.getName()); |
| DataCharacteristics mcIn3 = sec.getDataCharacteristics(input3.getName()); |
| DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName()); |
| if( qop.wtype2 != null || qop.wtype5 != null ) { |
| //output size determined by main input |
| mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getBlocksize(), mcIn1.getBlocksize()); |
| } |
| else if(qop.wtype3 != null ) { //wdivmm |
| long rank = qop.wtype3.isLeft() ? mcIn3.getCols() : mcIn2.getCols(); |
| DataCharacteristics mcTmp = qop.wtype3.computeOutputCharacteristics(mcIn1.getRows(), mcIn1.getCols(), rank); |
| mcOut.set(mcTmp.getRows(), mcTmp.getCols(), mcIn1.getBlocksize(), mcIn1.getBlocksize()); |
| } |
| } |
| |
| private abstract static class RDDQuaternaryBaseFunction implements Serializable |
| { |
| private static final long serialVersionUID = -3175397651350954930L; |
| |
| protected QuaternaryOperator _qop = null; |
| protected PartitionedBroadcast<MatrixBlock> _pmU = null; |
| protected PartitionedBroadcast<MatrixBlock> _pmV = null; |
| |
| public RDDQuaternaryBaseFunction( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) { |
| _qop = qop; |
| _pmU = bcU; |
| _pmV = bcV; |
| } |
| |
| protected MatrixIndexes createOutputIndexes(MatrixIndexes in) { |
| if( _qop.wtype3 != null && !_qop.wtype3.isBasic() ){ //key change |
| boolean left = _qop.wtype3.isLeft(); |
| return new MatrixIndexes(left?in.getColumnIndex():in.getRowIndex(), 1); |
| } |
| return in; |
| } |
| } |
| |
| private static class RDDQuaternaryFunction1 extends RDDQuaternaryBaseFunction //one rdd input |
| implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = -8209188316939435099L; |
| |
| public RDDQuaternaryFunction1( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) { |
| super(qop, bcU, bcV); |
| } |
| |
| @Override |
| public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) { |
| return new RDDQuaternaryPartitionIterator(arg); |
| } |
| |
| private class RDDQuaternaryPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> |
| { |
| public RDDQuaternaryPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) { |
| super(in); |
| } |
| |
| @Override |
| protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) { |
| MatrixIndexes ixIn = arg._1(); |
| MatrixBlock blkIn = arg._2(); |
| MatrixBlock blkOut = new MatrixBlock(); |
| MatrixBlock mbU = _pmU.getBlock((int)ixIn.getRowIndex(), 1); |
| MatrixBlock mbV = _pmV.getBlock((int)ixIn.getColumnIndex(), 1); |
| |
| //execute core operation |
| blkIn.quaternaryOperations(_qop, mbU, mbV, null, blkOut); |
| |
| //create return tuple |
| MatrixIndexes ixOut = createOutputIndexes(ixIn); |
| return new Tuple2<>(ixOut, blkOut); |
| } |
| } |
| } |
| |
| private static class RDDQuaternaryFunction2 extends RDDQuaternaryBaseFunction //two rdd input |
| implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock,MatrixBlock>>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = 7493974462943080693L; |
| |
| public RDDQuaternaryFunction2( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) { |
| super(qop, bcU, bcV); |
| } |
| |
| @Override |
| public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) { |
| MatrixIndexes ixIn = arg0._1(); |
| MatrixBlock blkIn1 = arg0._2()._1(); |
| MatrixBlock blkIn2 = arg0._2()._2(); |
| MatrixBlock blkOut = new MatrixBlock(); |
| MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2; |
| MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blkIn2; |
| MatrixBlock mbW = (_qop.hasFourInputs()) ? blkIn2 : null; |
| |
| //execute core operation |
| blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut); |
| |
| //create return tuple |
| MatrixIndexes ixOut = createOutputIndexes(ixIn); |
| return new Tuple2<>(ixOut, blkOut); |
| } |
| } |
| |
| private static class RDDQuaternaryFunction3 extends RDDQuaternaryBaseFunction //three rdd input |
| implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = -2294086455843773095L; |
| |
| public RDDQuaternaryFunction3( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) { |
| super(qop, bcU, bcV); |
| } |
| |
| @Override |
| public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg0) { |
| MatrixIndexes ixIn = arg0._1(); |
| MatrixBlock blkIn1 = arg0._2()._1()._1(); |
| MatrixBlock blkIn2 = arg0._2()._1()._2(); |
| MatrixBlock blkIn3 = arg0._2()._2(); |
| MatrixBlock blkOut = new MatrixBlock(); |
| MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2; |
| MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : |
| (_pmU!=null)? blkIn2 : blkIn3; |
| MatrixBlock mbW = (_qop.hasFourInputs())? blkIn3 : null; |
| |
| //execute core operation |
| blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut); |
| |
| //create return tuple |
| MatrixIndexes ixOut = createOutputIndexes(ixIn); |
| return new Tuple2<>(ixOut, blkOut); |
| } |
| } |
| |
| /** |
| * Note: never called for wsigmoid/wdivmm (only wsloss) |
| */ |
| private static class RDDQuaternaryFunction4 extends RDDQuaternaryBaseFunction //four rdd input |
| implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>,MatrixBlock>>,MatrixIndexes,MatrixBlock> |
| { |
| private static final long serialVersionUID = 7328911771600289250L; |
| |
| public RDDQuaternaryFunction4( QuaternaryOperator qop ) { |
| super(qop, null, null); |
| } |
| |
| @Override |
| public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock>> arg0) |
| { |
| MatrixIndexes ixIn1 = arg0._1(); |
| MatrixBlock blkIn1 = arg0._2()._1()._1()._1(); |
| MatrixBlock mbU = arg0._2()._1()._1()._2(); |
| MatrixBlock mbV = arg0._2()._1()._2(); |
| MatrixBlock mbW = arg0._2()._2(); |
| MatrixBlock blkOut = new MatrixBlock(); |
| |
| //execute core operation |
| blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut); |
| |
| //create return tuple |
| MatrixIndexes ixOut = createOutputIndexes(ixIn1); |
| return new Tuple2<>(ixOut, blkOut); |
| } |
| } |
| |
| private static class TransposeFactorIndexesFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> |
| { |
| private static final long serialVersionUID = -2571724736131823708L; |
| |
| @Override |
| public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) { |
| MatrixIndexes ixIn = arg0._1(); |
| MatrixBlock blkIn = arg0._2(); |
| |
| //swap the matrix indexes |
| MatrixIndexes ixOut = new MatrixIndexes(ixIn.getColumnIndex(), ixIn.getRowIndex()); |
| MatrixBlock blkOut = new MatrixBlock(blkIn); |
| |
| //output new tuple |
| return new Tuple2<>(ixOut,blkOut); |
| } |
| } |
| } |