/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupJoin;
import org.apache.sysds.runtime.instructions.spark.functions.ExtractGroupNWeights;
import org.apache.sysds.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
import org.apache.sysds.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixCell;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.data.WeightedCell;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.decode.Decoder;
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.transform.meta.TfOffsetMap;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
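/**
* Spark instruction for parameterized builtins such as grouped aggregates
* (groupedagg, mapgroupedagg), removeEmpty (rmempty), replace, triangular
* extraction (lowertri, uppertri), rexpand, and transformapply/transformdecode.
* For illustration, DML calls like the following compile to these opcodes
* (variable names are arbitrary):
* <pre>
* Y = removeEmpty(target=X, margin="rows");        # rmempty
* Z = replace(target=X, pattern=0, replacement=1); # replace
* </pre>
* The opcode determines the branch taken in processInstruction.
*/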
public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction {
protected HashMap<String, String> params;
// removeEmpty-specific attributes
private boolean _bRmEmptyBC = false;
ParameterizedBuiltinSPInstruction(Operator op, HashMap<String, String> paramsMap, CPOperand out, String opcode,
String istr, boolean bRmEmptyBC) {
super(SPType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
params = paramsMap;
_bRmEmptyBC = bRmEmptyBC;
}
public HashMap<String,String> getParams() { return params; }
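/**
* Parses name=value pairs from the instruction parts, skipping the first
* element (opcode) and the last element (output operand). For illustration,
* a hypothetical parts array ["rmempty", "target=X", "margin=rows", "_mVar1"]
* yields the map {target=X, margin=rows}.
*/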
public static HashMap<String, String> constructParameterMap(String[] params) {
// process all elements in "params" except first(opcode) and last(output)
HashMap<String,String> paramMap = new HashMap<>();
// all parameters are of form <name=value>
String[] parts;
for ( int i=1; i <= params.length-2; i++ ) {
parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
paramMap.put(parts[0], parts[1]);
}
return paramMap;
}
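/**
* Parses the serialized instruction string into a ParameterizedBuiltinSPInstruction.
* The mapgroupedagg opcode is handled separately because its operands are
* positional (target, groups, output, number of groups) rather than
* name=value pairs.
*/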
public static ParameterizedBuiltinSPInstruction parseInstruction ( String str ) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
// first part is always the opcode
String opcode = parts[0];
if( opcode.equalsIgnoreCase("mapgroupedagg") )
{
CPOperand target = new CPOperand( parts[1] );
CPOperand groups = new CPOperand( parts[2] );
CPOperand out = new CPOperand( parts[3] );
HashMap<String,String> paramsMap = new HashMap<>();
paramsMap.put(Statement.GAGG_TARGET, target.getName());
paramsMap.put(Statement.GAGG_GROUPS, groups.getName());
paramsMap.put(Statement.GAGG_NUM_GROUPS, parts[4]);
Operator op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.LASTCOLUMN);
return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
}
else
{
// last part is always the output
CPOperand out = new CPOperand( parts[parts.length-1] );
// process remaining parts and build a hash map
HashMap<String,String> paramsMap = constructParameterMap(parts);
// determine the appropriate value function
ValueFunction func = null;
if ( opcode.equalsIgnoreCase("groupedagg")) {
// check for mandatory arguments
String fnStr = paramsMap.get("fn");
if ( fnStr == null )
throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
if ( fnStr.equalsIgnoreCase("centralmoment") ) {
if ( paramsMap.get("order") == null )
throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");
}
Operator op = InstructionUtils.parseGroupedAggOperator(fnStr, paramsMap.get("order"));
return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
}
else if (opcode.equalsIgnoreCase("rmempty")) {
func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str,
parts.length > 6 ? Boolean.parseBoolean(parts[5]) : false);
}
else if (opcode.equalsIgnoreCase("rexpand")
|| opcode.equalsIgnoreCase("replace")
|| opcode.equalsIgnoreCase("lowertri")
|| opcode.equalsIgnoreCase("uppertri")
|| opcode.equalsIgnoreCase("transformapply")
|| opcode.equalsIgnoreCase("transformdecode")) {
func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
}
else {
throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
}
}
}
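/**
* Dispatches on the opcode and executes the corresponding Spark job. Most
* branches follow the same pattern: obtain input RDD handles and data
* characteristics, apply the RDD operations, set the output RDD handle,
* maintain lineage, and update the output characteristics.
*/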
@Override
@SuppressWarnings("unchecked")
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
String opcode = getOpcode();
//opcode guaranteed to be a valid opcode (see parsing)
if( opcode.equalsIgnoreCase("mapgroupedagg") )
{
//get input rdd handle
String targetVar = params.get(Statement.GAGG_TARGET);
String groupsVar = params.get(Statement.GAGG_GROUPS);
JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryMatrixBlockRDDHandleForVariable(targetVar);
PartitionedBroadcast<MatrixBlock> groups = sec.getBroadcastForVariable(groupsVar);
DataCharacteristics mc1 = sec.getDataCharacteristics( targetVar );
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
CPOperand ngrpOp = new CPOperand(params.get(Statement.GAGG_NUM_GROUPS));
int ngroups = (int)sec.getScalarInput(ngrpOp).getLongValue();
//single-block aggregation
if( ngroups <= mc1.getBlocksize() && mc1.getCols() <= mc1.getBlocksize() ) {
//execute map grouped aggregate
JavaRDD<MatrixBlock> out = target.map(new RDDMapGroupedAggFunction2(groups, _optr, ngroups));
MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
//put output block into symbol table (no lineage because single block)
//this also includes implicit maintenance of matrix characteristics
sec.setMatrixOutput(output.getName(), out2);
}
//multi-block aggregation
else {
//execute map grouped aggregate
JavaPairRDD<MatrixIndexes, MatrixBlock> out = target
.flatMapToPair(new RDDMapGroupedAggFunction(groups, _optr, ngroups, mc1.getBlocksize()));
out = RDDAggregateUtils.sumByKeyStable(out, false);
//update characteristics and set output handles
mcOut.set(ngroups, mc1.getCols(), mc1.getBlocksize(), -1);
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD( output.getName(), targetVar );
sec.addLineageBroadcast( output.getName(), groupsVar );
}
}
else if ( opcode.equalsIgnoreCase("groupedagg") )
{
boolean broadcastGroups = Boolean.parseBoolean(params.get("broadcast"));
//get input rdd handle
String groupsVar = params.get(Statement.GAGG_GROUPS);
JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryMatrixBlockRDDHandleForVariable( params.get(Statement.GAGG_TARGET) );
JavaPairRDD<MatrixIndexes,MatrixBlock> groups = broadcastGroups ? null : sec.getBinaryMatrixBlockRDDHandleForVariable( groupsVar );
JavaPairRDD<MatrixIndexes,MatrixBlock> weights = null;
DataCharacteristics mc1 = sec.getDataCharacteristics( params.get(Statement.GAGG_TARGET) );
DataCharacteristics mc2 = sec.getDataCharacteristics( groupsVar );
if(mc1.dimsKnown() && mc2.dimsKnown() && (mc1.getRows() != mc2.getRows() || mc2.getCols() != 1)) {
throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target and groups.");
}
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
JavaPairRDD<MatrixIndexes, WeightedCell> groupWeightedCells = null;
// Step 1: Extract groupWeightedCells from groups, target, and weights
if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
weights = sec.getBinaryMatrixBlockRDDHandleForVariable( params.get(Statement.GAGG_WEIGHTS) );
DataCharacteristics mc3 = sec.getDataCharacteristics( params.get(Statement.GAGG_WEIGHTS) );
if(mc1.dimsKnown() && mc3.dimsKnown() && (mc1.getRows() != mc3.getRows() || mc1.getCols() != mc3.getCols())) {
throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target, groups, and weights.");
}
groupWeightedCells = groups.join(target).join(weights)
.flatMapToPair(new ExtractGroupNWeights());
}
else //input vector or matrix
{
String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
long ngroups = (ngroupsStr != null) ? (long) Double.parseDouble(ngroupsStr) : -1;
//execute basic grouped aggregate (extract and preagg)
if( broadcastGroups ) {
PartitionedBroadcast<MatrixBlock> pbm = sec.getBroadcastForVariable(groupsVar);
groupWeightedCells = target
.flatMapToPair(new ExtractGroupBroadcast(pbm, mc1.getBlocksize(), ngroups, _optr));
}
else { //general case
//replicate groups if necessary
if( mc1.getNumColBlocks() > 1 ) {
groups = groups.flatMapToPair(
new ReplicateVectorFunction(false, mc1.getNumColBlocks() ));
}
groupWeightedCells = groups.join(target)
.flatMapToPair(new ExtractGroupJoin(mc1.getBlocksize(), ngroups, _optr));
}
}
// Step 2: Ensure the block size (blen) is known, as required to create <MatrixIndexes, MatrixCell> outputs
if(mc1.getBlocksize() == -1) {
throw new DMLRuntimeException("The block sizes are not specified for grouped aggregate");
}
int blen = mc1.getBlocksize();
// Step 3: Perform the grouped aggregate operation (either combiner-side or reducer-side)
JavaPairRDD<MatrixIndexes, MatrixCell> out = null;
if( (_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator())
|| _optr instanceof AggregateOperator ) {
out = groupWeightedCells.reduceByKey(new PerformGroupByAggInCombiner(_optr))
.mapValues(new CreateMatrixCell(blen, _optr));
}
else {
// use groupByKey because partial aggregation is not supported
out = groupWeightedCells.groupByKey()
.mapValues(new PerformGroupByAggInReducer(_optr))
.mapValues(new CreateMatrixCell(blen, _optr));
}
// Step 4: Set output characteristics and rdd handle
setOutputCharacteristicsForGroupedAgg(mc1, mcOut, out);
//store output rdd handle
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD( output.getName(), params.get(Statement.GAGG_TARGET) );
sec.addLineage( output.getName(), groupsVar, broadcastGroups );
if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_WEIGHTS) );
}
}
else if ( opcode.equalsIgnoreCase("rmempty") )
{
String rddInVar = params.get("target");
String rddOffVar = params.get("offset");
boolean rows = sec.getScalarInput(params.get("margin"), ValueType.STRING, true).getStringValue().equals("rows");
boolean emptyReturn = Boolean.parseBoolean(params.get("empty.return").toLowerCase());
long maxDim = sec.getScalarInput(params.get("maxdim"), ValueType.FP64, false).getLongValue();
DataCharacteristics mcIn = sec.getDataCharacteristics(rddInVar);
if( maxDim > 0 ) //default case
{
//get input rdd handle
JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( rddInVar );
JavaPairRDD<MatrixIndexes,MatrixBlock> off;
PartitionedBroadcast<MatrixBlock> broadcastOff;
long blen = mcIn.getBlocksize();
long numRep = (long)Math.ceil( rows ? (double)mcIn.getCols()/blen : (double)mcIn.getRows()/blen);
//execute remove empty rows/cols operation
JavaPairRDD<MatrixIndexes,MatrixBlock> out;
if( _bRmEmptyBC ) {
//obtain the offset vector via partitioned broadcast
broadcastOff = sec.getBroadcastForVariable( rddOffVar );
out = in
.flatMapToPair(new RDDRemoveEmptyFunctionInMem(rows, maxDim, blen, broadcastOff));
}
else {
off = sec.getBinaryMatrixBlockRDDHandleForVariable( rddOffVar );
out = in
.join( off.flatMapToPair(new ReplicateVectorFunction(!rows,numRep)) )
.flatMapToPair(new RDDRemoveEmptyFunction(rows, maxDim, blen));
}
out = RDDAggregateUtils.mergeByKey(out, false);
//store output rdd handle
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), rddInVar);
if(!_bRmEmptyBC)
sec.addLineageRDD(output.getName(), rddOffVar);
else
sec.addLineageBroadcast(output.getName(), rddOffVar);
//update output statistics (required for correctness)
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
mcOut.set(rows?maxDim:mcIn.getRows(), rows?mcIn.getCols():maxDim, (int)blen, mcIn.getNonZeros());
}
else //special case: empty output (ensure valid dims)
{
int n = emptyReturn ? 1 : 0;
MatrixBlock out = new MatrixBlock(rows?n:(int)mcIn.getRows(), rows?(int)mcIn.getCols():n, true);
sec.setMatrixOutput(output.getName(), out);
}
}
else if ( opcode.equalsIgnoreCase("replace") )
{
JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
//execute replace operation
double pattern = Double.parseDouble( params.get("pattern") );
double replacement = Double.parseDouble( params.get("replacement") );
JavaPairRDD<MatrixIndexes,MatrixBlock> out =
in1.mapValues(new RDDReplaceFunction(pattern, replacement));
//store output rdd handle
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), params.get("target"));
//update output statistics (required for correctness)
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
mcOut.set(mcIn.getRows(), mcIn.getCols(), mcIn.getBlocksize(),
(pattern!=0 && replacement!=0)?mcIn.getNonZeros():-1);
}
else if ( opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri") )
{
JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
boolean lower = opcode.equalsIgnoreCase("lowertri");
boolean diag = Boolean.parseBoolean(params.get("diag"));
boolean values = Boolean.parseBoolean(params.get("values"));
JavaPairRDD<MatrixIndexes,MatrixBlock> out = in1.mapPartitionsToPair(
new RDDExtractTriangularFunction(lower, diag, values), true);
//store output rdd handle
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), params.get("target"));
//update output statistics (required for correctness)
sec.getDataCharacteristics(output.getName()).setDimension(mcIn.getRows(), mcIn.getCols());
}
else if ( opcode.equalsIgnoreCase("rexpand") )
{
String rddInVar = params.get("target");
//get input rdd handle
JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( rddInVar );
DataCharacteristics mcIn = sec.getDataCharacteristics(rddInVar);
double maxVal = Double.parseDouble( params.get("max") );
long lmaxVal = UtilFunctions.toLong(maxVal);
boolean dirRows = params.get("dir").equals("rows");
boolean cast = Boolean.parseBoolean(params.get("cast"));
boolean ignore = Boolean.parseBoolean(params.get("ignore"));
long blen = mcIn.getBlocksize();
//repartition input vector for higher degree of parallelism
//(avoid scenarios where few input partitions create huge outputs)
DataCharacteristics mcTmp = new MatrixCharacteristics(dirRows?lmaxVal:mcIn.getRows(),
dirRows?mcIn.getRows():lmaxVal, (int)blen, mcIn.getRows());
int numParts = (int)Math.min(SparkUtils.getNumPreferredPartitions(mcTmp, in), mcIn.getNumBlocks());
if( numParts > in.getNumPartitions()*2 )
in = in.repartition(numParts);
//execute rexpand rows/cols operation (no shuffle required because outputs are
//block-aligned with the input, i.e., one input block generates n output blocks)
JavaPairRDD<MatrixIndexes,MatrixBlock> out = in
.flatMapToPair(new RDDRExpandFunction(maxVal, dirRows, cast, ignore, blen));
//store output rdd handle
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), rddInVar);
//update output statistics (required for correctness)
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
mcOut.set(dirRows?lmaxVal:mcIn.getRows(), dirRows?mcIn.getRows():lmaxVal, (int)blen, -1);
}
else if ( opcode.equalsIgnoreCase("transformapply") )
{
//get input RDD and meta data
FrameObject fo = sec.getFrameObject(params.get("target"));
JavaPairRDD<Long,FrameBlock> in = (JavaPairRDD<Long,FrameBlock>)
sec.getRDDHandleForFrameObject(fo, FileFormat.BINARY);
FrameBlock meta = sec.getFrameInput(params.get("meta"));
DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
String[] colnames = !TfMetaUtils.isIDSpec(params.get("spec")) ?
in.lookup(1L).get(0).getColumnNames() : null;
//compute omit offset map for block shifts
TfOffsetMap omap = null;
if( TfMetaUtils.containsOmitSpec(params.get("spec"), colnames) ) {
omap = new TfOffsetMap(SparkUtils.toIndexedLong(in.mapToPair(
new RDDTransformApplyOffsetFunction(params.get("spec"), colnames)).collect()));
}
//create encoder broadcast (avoiding replication per task)
Encoder encoder = EncoderFactory.createEncoder(params.get("spec"), colnames,
fo.getSchema(), (int)fo.getNumColumns(), meta);
mcOut.setDimension(mcIn.getRows()-((omap!=null)?omap.getNumRmRows():0), encoder.getNumCols());
Broadcast<Encoder> bmeta = sec.getSparkContext().broadcast(encoder);
Broadcast<TfOffsetMap> bomap = (omap!=null) ? sec.getSparkContext().broadcast(omap) : null;
//execute transform apply
JavaPairRDD<Long,FrameBlock> tmp = in
.mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
JavaPairRDD<MatrixIndexes,MatrixBlock> out = FrameRDDConverterUtils
.binaryBlockToMatrixBlock(tmp, mcOut, mcOut);
//set output and maintain lineage/output characteristics
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), params.get("target"));
ec.releaseFrameInput(params.get("meta"));
}
else if ( opcode.equalsIgnoreCase("transformdecode") )
{
//get input RDD and meta data
JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
DataCharacteristics mc = sec.getDataCharacteristics(params.get("target"));
FrameBlock meta = sec.getFrameInput(params.get("meta"));
String[] colnames = meta.getColumnNames();
//reblock if necessary (clen > blen)
if( mc.getCols() > mc.getNumColBlocks() ) {
in = in.mapToPair(new RDDTransformDecodeExpandFunction(
(int)mc.getCols(), mc.getBlocksize()));
in = RDDAggregateUtils.mergeByKey(in, false);
}
//construct decoder and decode individual matrix blocks
Decoder decoder = DecoderFactory.createDecoder(params.get("spec"), colnames, null, meta);
JavaPairRDD<Long,FrameBlock> out = in.mapToPair(
new RDDTransformDecodeFunction(decoder, mc.getBlocksize()));
//set output and maintain lineage/output characteristics
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), params.get("target"));
ec.releaseFrameInput(params.get("meta"));
sec.getDataCharacteristics(output.getName()).set(
mc.getRows(), meta.getNumColumns(), mc.getBlocksize(), -1);
sec.getFrameObject(output.getName()).setSchema(decoder.getSchema());
}
else {
throw new DMLRuntimeException("Unknown parameterized builtin opcode: "+opcode);
}
}
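/**
* Replaces all occurrences of a pattern value with a replacement value in
* each matrix block, e.g., for a DML call such as
* replace(target=X, pattern=0, replacement=1) (values are illustrative).
*/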
public static class RDDReplaceFunction implements Function<MatrixBlock, MatrixBlock> {
private static final long serialVersionUID = 6576713401901671659L;
private double _pattern;
private double _replacement;
public RDDReplaceFunction(double pattern, double replacement) {
_pattern = pattern;
_replacement = replacement;
}
@Override
public MatrixBlock call(MatrixBlock arg0) {
return arg0.replaceOperations(new MatrixBlock(), _pattern, _replacement);
}
}
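/**
* Extracts the lower or upper triangular part of a blocked matrix in a
* partition-preserving manner: off-diagonal blocks are passed through,
* overwritten with a constant, or emptied depending on their position,
* and only blocks on the diagonal require an actual triangular extraction.
*/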
private static class RDDExtractTriangularFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = 2754868819184155702L;
private final boolean _lower, _diag, _values;
public RDDExtractTriangularFunction(boolean lower, boolean diag, boolean values) {
_lower = lower;
_diag = diag;
_values = values;
}
@Override
public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) {
return new ExtractTriangularIterator(arg0);
}
private class ExtractTriangularIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>>
{
public ExtractTriangularIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
super(in);
}
@Override
protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) {
MatrixIndexes ix = arg._1();
MatrixBlock mb = arg._2();
//handle cases of pass-through and reset block
if( (_lower && ix.getRowIndex() > ix.getColumnIndex())
|| (!_lower && ix.getRowIndex() < ix.getColumnIndex()) ) {
return _values ? arg : new Tuple2<>(
ix, new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1d));
}
//handle cases of empty blocks
if( (_lower && ix.getRowIndex() < ix.getColumnIndex())
|| (!_lower && ix.getRowIndex() > ix.getColumnIndex()) ) {
return new Tuple2<>(ix,
new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), true));
}
//extract triangular blocks for blocks on diagonal
assert(ix.getRowIndex() == ix.getColumnIndex());
return new Tuple2<>(ix,
mb.extractTriangular(new MatrixBlock(), _lower, _diag, _values));
}
}
}
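/**
* Removes empty rows or columns from a joined pair of data and offset blocks,
* where the offset vector determines the target positions of the remaining
* rows/columns in the compacted output.
*/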
public static class RDDRemoveEmptyFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 4906304771183325289L;
private final boolean _rmRows;
private final long _len;
private final long _blen;
public RDDRemoveEmptyFunction(boolean rmRows, long len, long blen) {
_rmRows = rmRows;
_len = len;
_blen = blen;
}
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0)
throws Exception
{
//prepare inputs (for internal API compatibility)
IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._1());
IndexedMatrixValue offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._2());
//execute remove empty operations
ArrayList<IndexedMatrixValue> out = new ArrayList<>();
LibMatrixReorg.rmempty(data, offsets, _rmRows, _len, _blen, out);
//prepare and return outputs
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
}
}
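/**
* Broadcast-based variant of RDDRemoveEmptyFunction: instead of joining the
* data with a replicated offset RDD, the offset block is obtained from a
* partitioned broadcast, which avoids the shuffle of the join.
*/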
public static class RDDRemoveEmptyFunctionInMem implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 4906304771183325289L;
private final boolean _rmRows;
private final long _len;
private final long _blen;
private PartitionedBroadcast<MatrixBlock> _off = null;
public RDDRemoveEmptyFunctionInMem(boolean rmRows, long len, long blen, PartitionedBroadcast<MatrixBlock> off) {
_rmRows = rmRows;
_len = len;
_blen = blen;
_off = off;
}
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
throws Exception
{
//prepare inputs (for internal API compatibility)
IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2());
IndexedMatrixValue offsets = _rmRows ?
SparkUtils.toIndexedMatrixBlock(arg0._1(), _off.getBlock((int)arg0._1().getRowIndex(), 1)) :
SparkUtils.toIndexedMatrixBlock(arg0._1(), _off.getBlock(1, (int)arg0._1().getColumnIndex()));
//execute remove empty operations
ArrayList<IndexedMatrixValue> out = new ArrayList<>();
LibMatrixReorg.rmempty(data, offsets, _rmRows, _len, _blen, out);
//prepare and return outputs
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
}
}
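/**
* Expands each input vector block into output blocks of indicator vectors
* (rexpand/table semantics) in row or column direction; no shuffle is
* required because the outputs are block-aligned with the input.
*/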
public static class RDDRExpandFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = -6153643261956222601L;
private double _maxVal;
private boolean _dirRows;
private boolean _cast;
private boolean _ignore;
private long _blen;
public RDDRExpandFunction(double maxVal, boolean dirRows, boolean cast, boolean ignore, long blen)
{
_maxVal = maxVal;
_dirRows = dirRows;
_cast = cast;
_ignore = ignore;
_blen = blen;
}
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
throws Exception
{
//prepare inputs (for internal API compatibility)
IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2());
//execute rexpand operations
ArrayList<IndexedMatrixValue> out = new ArrayList<>();
LibMatrixReorg.rexpand(data, _maxVal, _dirRows, _cast, _ignore, _blen, out);
//prepare and return outputs
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
}
}
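/**
* Map-side grouped aggregate: combines each target block with the co-located
* block of the broadcast groups vector and emits partial aggregate blocks,
* which are subsequently summed by key.
*/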
public static class RDDMapGroupedAggFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 6795402640178679851L;
private PartitionedBroadcast<MatrixBlock> _pbm = null;
private Operator _op = null;
private int _ngroups = -1;
private int _blen = -1;
public RDDMapGroupedAggFunction(PartitionedBroadcast<MatrixBlock> pbm, Operator op, int ngroups, int blen) {
_pbm = pbm;
_op = op;
_ngroups = ngroups;
_blen = blen;
}
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
throws Exception
{
//get all inputs
MatrixIndexes ix = arg0._1();
MatrixBlock target = arg0._2();
MatrixBlock groups = _pbm.getBlock((int)ix.getRowIndex(), 1);
//execute map grouped aggregate operations
IndexedMatrixValue in1 = SparkUtils.toIndexedMatrixBlock(ix, target);
ArrayList<IndexedMatrixValue> outlist = new ArrayList<>();
OperationsOnMatrixValues.performMapGroupedAggregate(_op, in1, groups, _ngroups, _blen, outlist);
//output all result blocks
return SparkUtils.fromIndexedMatrixBlock(outlist).iterator();
}
}
/**
* Similar to RDDMapGroupedAggFunction but produces a single output block,
* which enables a subsequent local summation of all partial results.
*/
public static class RDDMapGroupedAggFunction2 implements Function<Tuple2<MatrixIndexes,MatrixBlock>,MatrixBlock>
{
private static final long serialVersionUID = -6820599604299797661L;
private PartitionedBroadcast<MatrixBlock> _pbm = null;
private Operator _op = null;
private int _ngroups = -1;
public RDDMapGroupedAggFunction2(PartitionedBroadcast<MatrixBlock> pbm, Operator op, int ngroups) {
_pbm = pbm;
_op = op;
_ngroups = ngroups;
}
@Override
public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
throws Exception
{
//get all inputs
MatrixIndexes ix = arg0._1();
MatrixBlock target = arg0._2();
MatrixBlock groups = _pbm.getBlock((int)ix.getRowIndex(), 1);
//execute map grouped aggregate operations
return groups.groupedAggOperations(target, null, new MatrixBlock(), _ngroups, _op);
}
}
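/**
* Finalizes a WeightedCell into a MatrixCell, e.g., by dividing the
* accumulated value by the accumulated weight for mean/variance-style
* aggregates.
*/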
public static class CreateMatrixCell implements Function<WeightedCell, MatrixCell>
{
private static final long serialVersionUID = -5783727852453040737L;
int blen; Operator op;
public CreateMatrixCell(int blen, Operator op) {
this.blen = blen;
this.op = op;
}
@Override
public MatrixCell call(WeightedCell kv)
throws Exception
{
double val = -1;
if(op instanceof CMOperator)
{
AggregateOperationTypes agg = ((CMOperator)op).aggOpType;
switch(agg)
{
case COUNT:
val = kv.getWeight();
break;
case MEAN:
val = kv.getValue();
break;
case CM2: //central moments and variance
case CM3: //are stored as value/weight
case CM4:
case VARIANCE:
val = kv.getValue() / kv.getWeight();
break;
default:
throw new DMLRuntimeException("Invalid aggregate in CM_CV_Object: " + agg);
}
}
else
{
//other aggregates: normalize the accumulated value by its weight
val = kv.getValue()/kv.getWeight();
}
return new MatrixCell(val);
}
}
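/**
* Applies the broadcast encoder to each frame block and converts the result
* back into a frame block so the existing frame-matrix reblock can be reused;
* row keys are remapped via the broadcast offset map if rows were omitted.
*/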
public static class RDDTransformApplyFunction implements PairFunction<Tuple2<Long,FrameBlock>,Long,FrameBlock>
{
private static final long serialVersionUID = 5759813006068230916L;
private Broadcast<Encoder> _bencoder = null;
private Broadcast<TfOffsetMap> _omap = null;
public RDDTransformApplyFunction(Broadcast<Encoder> bencoder, Broadcast<TfOffsetMap> omap) {
_bencoder = bencoder;
_omap = omap;
}
@Override
public Tuple2<Long,FrameBlock> call(Tuple2<Long, FrameBlock> in)
throws Exception
{
long key = in._1();
FrameBlock blk = in._2();
//execute block transform apply
Encoder encoder = _bencoder.getValue();
MatrixBlock tmp = encoder.apply(blk, new MatrixBlock(blk.getNumRows(), blk.getNumColumns(), false));
//remap keys
if( _omap != null ) {
key = _omap.getValue().getOffset(key);
}
//convert to FrameBlock to reuse the frame-matrix reblock
return new Tuple2<>(key,
DataConverter.convertToFrameBlock(tmp));
}
}
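/**
* Counts, per frame block, the number of rows removed by the omit transform
* (rows with a missing value in any of the omit columns); these counts are
* used to construct the offset map for block shifts.
*/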
public static class RDDTransformApplyOffsetFunction implements PairFunction<Tuple2<Long,FrameBlock>,Long,Long>
{
private static final long serialVersionUID = 3450977356721057440L;
private int[] _omitColList = null;
public RDDTransformApplyOffsetFunction(String spec, String[] colnames) {
try {
_omitColList = TfMetaUtils.parseJsonIDList(
spec, colnames, TfMethod.OMIT.toString());
}
catch (DMLRuntimeException e) {
throw new RuntimeException(e);
}
}
@Override
public Tuple2<Long,Long> call(Tuple2<Long, FrameBlock> in)
throws Exception
{
long key = in._1();
long rmRows = 0;
FrameBlock blk = in._2();
for( int i=0; i<blk.getNumRows(); i++ ) {
boolean valid = true;
for( int j=0; j<_omitColList.length; j++ ) {
int colID = _omitColList[j];
Object val = blk.get(i, colID-1);
valid &= !(val==null || (blk.getSchema()[colID-1]==
ValueType.STRING && val.toString().isEmpty()));
}
rmRows += valid ? 0 : 1;
}
return new Tuple2<>(key, rmRows);
}
}
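/**
* Decodes each matrix block back into a frame block with the decoder's schema
* and column names, keyed by the global row index of the block's first row.
*/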
public static class RDDTransformDecodeFunction implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>,Long,FrameBlock>
{
private static final long serialVersionUID = -4797324742568170756L;
private Decoder _decoder = null;
private int _blen = -1;
public RDDTransformDecodeFunction(Decoder decoder, int blen) {
_decoder = decoder;
_blen = blen;
}
@Override
public Tuple2<Long,FrameBlock> call(Tuple2<MatrixIndexes, MatrixBlock> in)
throws Exception
{
long rix = UtilFunctions.computeCellIndex(in._1().getRowIndex(), _blen, 0);
FrameBlock fbout = _decoder.decode(in._2(), new FrameBlock(_decoder.getSchema()));
fbout.setColumnNames(Arrays.copyOfRange(_decoder.getColnames(), 0, fbout.getNumColumns()));
return new Tuple2<>(rix, fbout);
}
}
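/**
* Expands blocks of matrices with multiple column blocks into blocks that
* span all columns via left indexing, so that each row block decodes into a
* single frame block; the partial blocks are subsequently merged by key.
*/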
public static class RDDTransformDecodeExpandFunction implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = -8187400248076127598L;
private int _clen = -1;
private int _blen = -1;
public RDDTransformDecodeExpandFunction(int clen, int blen) {
_clen = clen;
_blen = blen;
}
@Override
public Tuple2<MatrixIndexes,MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> in)
throws Exception
{
MatrixIndexes inIx = in._1();
MatrixBlock inBlk = in._2();
//construct expanded block via left indexing
int cl = (int)UtilFunctions.computeCellIndex(inIx.getColumnIndex(), _blen, 0)-1;
int cu = (int)UtilFunctions.computeCellIndex(inIx.getColumnIndex(), _blen, inBlk.getNumColumns()-1)-1;
MatrixBlock out = new MatrixBlock(inBlk.getNumRows(), _clen, false);
out = out.leftIndexingOperations(inBlk, 0, inBlk.getNumRows()-1, cl, cu, null, UpdateType.INPLACE_PINNED);
return new Tuple2<>(new MatrixIndexes(inIx.getRowIndex(), 1), out);
}
}
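/**
* Sets the output characteristics for grouped aggregate: if the number of
* groups is known, the output dimensions follow directly; otherwise the cell
* output is cached and its characteristics are computed via a Spark job.
*/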
public void setOutputCharacteristicsForGroupedAgg(DataCharacteristics mc1, DataCharacteristics mcOut, JavaPairRDD<MatrixIndexes, MatrixCell> out) {
if(!mcOut.dimsKnown()) {
if(!mc1.dimsKnown()) {
throw new DMLRuntimeException("The output dimensions are not specified for grouped aggregate");
}
if ( params.get(Statement.GAGG_NUM_GROUPS) != null) {
int ngroups = (int) Double.parseDouble(params.get(Statement.GAGG_NUM_GROUPS));
mcOut.set(ngroups, mc1.getCols(), -1, -1); //grouped aggregate with cell output
}
else {
out = SparkUtils.cacheBinaryCellRDD(out);
mcOut.set(SparkUtils.computeDataCharacteristics(out));
mcOut.setBlocksize(-1); //grouped aggregate with cell output
}
}
}
}