blob: 7f2210895ca69f9ae7fd998538a799a650aae657 [file] [log] [blame]
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
package org.apache.sysds.runtime.instructions.spark;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysds.lops.WeightedUnaryMMR;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Iterator;
public class QuaternarySPInstruction extends ComputationSPInstruction {
// Fourth input operand; null for operations that only take three inputs
// (the wumm and wsigmoid parse paths pass null for it).
private CPOperand _input4;
// When true, input2 (U) is read through a broadcast in the reduce-side path;
// when false it is read as an RDD and joined with the main input.
private boolean _cacheU;
// Same as _cacheU, but for input3 (V).
private boolean _cacheV;
/**
 * Creates a quaternary Spark instruction. Instances are only built by
 * {@link #parseInstruction(String)}.
 *
 * @param op     operator, forwarded to the superclass
 * @param in1    first input operand (main matrix input)
 * @param in2    second input operand (factor U)
 * @param in3    third input operand (factor V)
 * @param in4    optional fourth input operand, may be null
 * @param out    output operand
 * @param cacheU whether the second input is accessed via broadcast
 * @param cacheV whether the third input is accessed via broadcast
 * @param opcode instruction opcode
 * @param str    full instruction string
 */
private QuaternarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
	CPOperand out, boolean cacheU, boolean cacheV, String opcode, String str) {
	super(SPType.Quaternary, op, in1, in2, in3, out, opcode, str);
	_input4 = in4;
	_cacheU = cacheU;
	_cacheV = cacheV;
}
public static QuaternarySPInstruction parseInstruction( String str ) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
//validity check
if( !InstructionUtils.isDistQuaternaryOpcode(opcode) ) {
throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
//instruction parsing
if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) //wsloss
|| WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) )
boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
//check number of fields (4 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields ( parts, 8 );
InstructionUtils.checkNumFields ( parts, 6 );
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
CPOperand in4 = new CPOperand(parts[4]);
CPOperand out = new CPOperand(parts[5]);
WeightsType wtype = WeightsType.valueOf(parts[6]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
else if( WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //wumm
|| WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode) )
boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
//check number of fields (4 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields ( parts, 8 );
InstructionUtils.checkNumFields ( parts, 6 );
String uopcode = parts[1];
CPOperand in1 = new CPOperand(parts[2]);
CPOperand in2 = new CPOperand(parts[3]);
CPOperand in3 = new CPOperand(parts[4]);
CPOperand out = new CPOperand(parts[5]);
WUMMType wtype = WUMMType.valueOf(parts[6]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
else if( WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) //wdivmm
|| WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) )
boolean isRed = opcode.startsWith("red");
//check number of fields (4 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields( parts, 8 );
InstructionUtils.checkNumFields( parts, 6 );
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
CPOperand in4 = new CPOperand(parts[4]);
CPOperand out = new CPOperand(parts[5]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
final WDivMMType wt = WDivMMType.valueOf(parts[6]);
QuaternaryOperator qop = (wt.hasScalar() ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName())) : new QuaternaryOperator(wt));
return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
else //map/redwsigmoid, map/redwcemm
boolean isRed = opcode.startsWith("red");
int addInput4 = (opcode.endsWith("wcemm")) ? 1 : 0;
//check number of fields (3 or 4 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields( parts, 7 + addInput4 );
InstructionUtils.checkNumFields( parts, 5 + addInput4 );
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
CPOperand out = new CPOperand(parts[4 + addInput4]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true;
if( opcode.endsWith("wsigmoid") )
return new QuaternarySPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
else if( opcode.endsWith("wcemm") ) {
CPOperand in4 = new CPOperand(parts[4]);
final WCeMMType wt = WCeMMType.valueOf(parts[6]);
QuaternaryOperator qop = (wt.hasFourInputs() ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName())) : new QuaternaryOperator(wt));
return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
return null;
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
QuaternaryOperator qop = (QuaternaryOperator) _optr;
//tracking of rdds and broadcasts (for lineage maintenance)
ArrayList<String> rddVars = new ArrayList<>();
ArrayList<String> bcVars = new ArrayList<>();
JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
DataCharacteristics inMc = sec.getDataCharacteristics( input1.getName() );
long rlen = inMc.getRows();
long clen = inMc.getCols();
int blen = inMc.getBlocksize();
//pre-filter empty blocks (ultra-sparse matrices) for full aggregates
//(map/redwsloss, map/redwcemm); safe because theses ops produce a scalar
if( qop.wtype1 != null || qop.wtype4 != null ) {
in = in.filter(new FilterNonEmptyBlocksFunction());
//map-side only operation (one rdd input, two broadcasts)
if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedSigmoid.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedDivMM.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedCrossEntropy.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedUnaryMM.OPCODE.equalsIgnoreCase(getOpcode()))
PartitionedBroadcast<MatrixBlock> bc1 = sec.getBroadcastForVariable( input2.getName() );
PartitionedBroadcast<MatrixBlock> bc2 = sec.getBroadcastForVariable( input3.getName() );
//partitioning-preserving mappartitions (key access required for broadcast loopkup)
boolean noKeyChange = (qop.wtype3 == null || qop.wtype3.isBasic()); //only wdivmm changes keys
out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), noKeyChange);
rddVars.add( input1.getName() );
bcVars.add( input2.getName() );
bcVars.add( input3.getName() );
//reduce-side operation (two/three/four rdd inputs, zero/one/two broadcasts)
PartitionedBroadcast<MatrixBlock> bc1 = _cacheU ? sec.getBroadcastForVariable( input2.getName() ) : null;
PartitionedBroadcast<MatrixBlock> bc2 = _cacheV ? sec.getBroadcastForVariable( input3.getName() ) : null;
JavaPairRDD<MatrixIndexes,MatrixBlock> inU = (!_cacheU) ? sec.getBinaryMatrixBlockRDDHandleForVariable( input2.getName() ) : null;
JavaPairRDD<MatrixIndexes,MatrixBlock> inV = (!_cacheV) ? sec.getBinaryMatrixBlockRDDHandleForVariable( input3.getName() ) : null;
JavaPairRDD<MatrixIndexes,MatrixBlock> inW = (qop.hasFourInputs() && !_input4.isLiteral()) ?
sec.getBinaryMatrixBlockRDDHandleForVariable( _input4.getName() ) : null;
//preparation of transposed and replicated U
if( inU != null )
inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, blen, true));
//preparation of transposed and replicated V
if( inV != null )
inV = inV.mapToPair(new TransposeFactorIndexesFunction())
.flatMapToPair(new ReplicateBlockFunction(rlen, blen, false));
//functions calls w/ two rdd inputs
if( inU != null && inV == null && inW == null )
out = in.join(inU)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if( inU == null && inV != null && inW == null )
out = in.join(inV)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if( inU == null && inV == null && inW != null )
out = in.join(inW)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
//function calls w/ three rdd inputs
else if( inU != null && inV != null && inW == null )
out = in.join(inU).join(inV)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU != null && inV == null && inW != null )
out = in.join(inU).join(inW)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU == null && inV != null && inW != null )
out = in.join(inV).join(inW)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU == null && inV == null && inW == null ) {
out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), false);
//function call w/ four rdd inputs
else //need keys in case of wdivmm
out = in.join(inU).join(inV).join(inW)
.mapToPair(new RDDQuaternaryFunction4(qop));
//keep variable names for lineage maintenance
if( inU == null ) bcVars.add(input2.getName()); else rddVars.add(input2.getName());
if( inV == null ) bcVars.add(input3.getName()); else rddVars.add(input3.getName());
if( inW != null ) rddVars.add(_input4.getName());
//output handling, incl aggregation
if( qop.wtype1 != null || qop.wtype4 != null ) //map/redwsloss, map/redwcemm
//full aggregate and cast to scalar
MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
DoubleObject ret = new DoubleObject(tmp.getValue(0, 0));
sec.setVariable(output.getName(), ret);
else //map/redwsigmoid, map/redwdivmm, map/redwumm
//aggregation if required (map/redwdivmm)
if( qop.wtype3 != null && !qop.wtype3.isBasic() )
out = RDDAggregateUtils.sumByKeyStable(out, false);
//put output RDD handle into symbol table
sec.setRDDHandleForVariable(output.getName(), out);
//maintain lineage information for output rdd
for( String rddVar : rddVars )
sec.addLineageRDD(output.getName(), rddVar);
for( String bcVar : bcVars )
sec.addLineageBroadcast(output.getName(), bcVar);
//update matrix characteristics
updateOutputDataCharacteristics(sec, qop);
private void updateOutputDataCharacteristics(SparkExecutionContext sec, QuaternaryOperator qop) {
DataCharacteristics mcIn1 = sec.getDataCharacteristics(input1.getName());
DataCharacteristics mcIn2 = sec.getDataCharacteristics(input2.getName());
DataCharacteristics mcIn3 = sec.getDataCharacteristics(input3.getName());
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
if( qop.wtype2 != null || qop.wtype5 != null ) {
//output size determined by main input
mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getBlocksize(), mcIn1.getBlocksize());
else if(qop.wtype3 != null ) { //wdivmm
long rank = qop.wtype3.isLeft() ? mcIn3.getCols() : mcIn2.getCols();
DataCharacteristics mcTmp = qop.wtype3.computeOutputCharacteristics(mcIn1.getRows(), mcIn1.getCols(), rank);
mcOut.set(mcTmp.getRows(), mcTmp.getCols(), mcIn1.getBlocksize(), mcIn1.getBlocksize());
private abstract static class RDDQuaternaryBaseFunction implements Serializable
private static final long serialVersionUID = -3175397651350954930L;
protected QuaternaryOperator _qop = null;
protected PartitionedBroadcast<MatrixBlock> _pmU = null;
protected PartitionedBroadcast<MatrixBlock> _pmV = null;
public RDDQuaternaryBaseFunction( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) {
_qop = qop;
_pmU = bcU;
_pmV = bcV;
protected MatrixIndexes createOutputIndexes(MatrixIndexes in) {
if( _qop.wtype3 != null && !_qop.wtype3.isBasic() ){ //key change
boolean left = _qop.wtype3.isLeft();
return new MatrixIndexes(left?in.getColumnIndex():in.getRowIndex(), 1);
return in;
private static class RDDQuaternaryFunction1 extends RDDQuaternaryBaseFunction //one rdd input
implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock>
private static final long serialVersionUID = -8209188316939435099L;
public RDDQuaternaryFunction1( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) {
super(qop, bcU, bcV);
public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) {
return new RDDQuaternaryPartitionIterator(arg);
private class RDDQuaternaryPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>>
public RDDQuaternaryPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) {
MatrixIndexes ixIn = arg._1();
MatrixBlock blkIn = arg._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = _pmU.getBlock((int)ixIn.getRowIndex(), 1);
MatrixBlock mbV = _pmV.getBlock((int)ixIn.getColumnIndex(), 1);
//execute core operation
blkIn.quaternaryOperations(_qop, mbU, mbV, null, blkOut);
//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2<>(ixOut, blkOut);
private static class RDDQuaternaryFunction2 extends RDDQuaternaryBaseFunction //two rdd input
implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock,MatrixBlock>>, MatrixIndexes, MatrixBlock>
private static final long serialVersionUID = 7493974462943080693L;
public RDDQuaternaryFunction2( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) {
super(qop, bcU, bcV);
public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) {
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1();
MatrixBlock blkIn2 = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blkIn2;
MatrixBlock mbW = (_qop.hasFourInputs()) ? blkIn2 : null;
//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2<>(ixOut, blkOut);
private static class RDDQuaternaryFunction3 extends RDDQuaternaryBaseFunction //three rdd input
implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock>
private static final long serialVersionUID = -2294086455843773095L;
public RDDQuaternaryFunction3( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) {
super(qop, bcU, bcV);
public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg0) {
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1()._1();
MatrixBlock blkIn2 = arg0._2()._1()._2();
MatrixBlock blkIn3 = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) :
(_pmU!=null)? blkIn2 : blkIn3;
MatrixBlock mbW = (_qop.hasFourInputs())? blkIn3 : null;
//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2<>(ixOut, blkOut);
* Note: never called for wsigmoid/wdivmm (only wsloss)
private static class RDDQuaternaryFunction4 extends RDDQuaternaryBaseFunction //four rdd input
implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>,MatrixBlock>>,MatrixIndexes,MatrixBlock>
private static final long serialVersionUID = 7328911771600289250L;
public RDDQuaternaryFunction4( QuaternaryOperator qop ) {
super(qop, null, null);
public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock>> arg0)
MatrixIndexes ixIn1 = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1()._1()._1();
MatrixBlock mbU = arg0._2()._1()._1()._2();
MatrixBlock mbV = arg0._2()._1()._2();
MatrixBlock mbW = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();
//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn1);
return new Tuple2<>(ixOut, blkOut);
private static class TransposeFactorIndexesFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
private static final long serialVersionUID = -2571724736131823708L;
public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) {
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn = arg0._2();
//swap the matrix indexes
MatrixIndexes ixOut = new MatrixIndexes(ixIn.getColumnIndex(), ixIn.getRowIndex());
MatrixBlock blkOut = new MatrixBlock(blkIn);
//output new tuple
return new Tuple2<>(ixOut,blkOut);