// blob: cecbd3ddf46ede9e9ea90a6a0a0d24c4e5cf359a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.data.BasicTensorBlock;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.functions.AggregateDropCorrectionFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import scala.Tuple2;
/**
 * Spark instruction that evaluates a unary aggregate operator over a blocked
 * matrix or tensor input.
 * <p>
 * The {@link SparkAggType} given in the instruction string selects how the
 * per-block partial results are combined:
 * <ul>
 *   <li>{@code SINGLE_BLOCK}: all partial results are aggregated into one block,
 *       which is written to the symbol table as an in-memory output.</li>
 *   <li>{@code MULTI_BLOCK}: partial results are aggregated by key into a new RDD
 *       (matrices only; tensors raise a {@link DMLRuntimeException}).</li>
 *   <li>{@code NONE}: every block is aggregated independently into a new RDD.</li>
 * </ul>
 */
public class AggregateUnarySPInstruction extends UnarySPInstruction {
	/** Strategy for combining per-block partial aggregates. */
	private final SparkAggType _aggtype;
	/** Operator used to merge partial aggregates across blocks. */
	private final AggregateOperator _aop;

	/**
	 * @param type    Spark instruction type
	 * @param auop    unary aggregate operator applied to each input block
	 * @param aop     aggregate operator used to merge partial results
	 * @param in      input operand
	 * @param out     output operand
	 * @param aggtype strategy for combining partial results across blocks
	 * @param opcode  instruction opcode
	 * @param istr    original instruction string
	 */
	protected AggregateUnarySPInstruction(SPType type, AggregateUnaryOperator auop, AggregateOperator aop, CPOperand in,
		CPOperand out, SparkAggType aggtype, String opcode, String istr) {
		super(type, auop, in, out, opcode, istr);
		_aggtype = aggtype;
		_aop = aop;
	}

	/**
	 * Parses an aggregate unary Spark instruction.
	 * <p>
	 * After the opcode, exactly three fields are expected: the input operand,
	 * the output operand, and the name of a {@link SparkAggType} constant.
	 *
	 * @param str instruction string
	 * @return the parsed instruction
	 * @throws IllegalArgumentException if the last field is not a valid {@link SparkAggType} name
	 */
	public static AggregateUnarySPInstruction parseInstruction(String str) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		InstructionUtils.checkNumFields(parts, 3);
		String opcode = parts[0];
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand out = new CPOperand(parts[2]);
		SparkAggType aggtype = SparkAggType.valueOf(parts[3]);

		//per-block unary aggregate operator, plus the operator (with its correction
		//location) that is used to merge partial results across blocks
		String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode);
		CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode);
		AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
		AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrLoc.toString());
		return new AggregateUnarySPInstruction(SPType.AggregateUnary, aggun, aop, in1, out, aggtype, opcode, str);
	}

	/**
	 * Dispatches to the matrix implementation for matrix inputs and to the
	 * tensor implementation for all other data types.
	 */
	@Override
	public void processInstruction( ExecutionContext ec ) {
		if (input1.getDataType() == Types.DataType.MATRIX) {
			processMatrixAggregate(ec);
		} else {
			processTensorAggregate(ec);
		}
	}

	/**
	 * Executes the unary aggregate over a blocked matrix RDD and binds the
	 * result to the output variable.
	 */
	private void processMatrixAggregate(ExecutionContext ec) {
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		DataCharacteristics mc = sec.getDataCharacteristics(input1.getName());

		//get input
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );

		//filter input blocks for trace
		if( getOpcode().equalsIgnoreCase("uaktrace") ) {
			out = out.filter(new FilterDiagMatrixBlocksFunction());
		}

		//execute unary aggregate operation
		AggregateUnaryOperator auop = (AggregateUnaryOperator)_optr;
		AggregateOperator aggop = _aop;

		//perform aggregation if necessary and put output into symbol table
		if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
			//for sparse-safe operators, only non-empty blocks are aggregated
			if( auop.sparseSafe ) {
				out = out.filter(new FilterNonEmptyBlocksFunction());
			}
			JavaRDD<MatrixBlock> out2 = out.map(
				new RDDUAggFunction2(auop, mc.getBlocksize()));
			MatrixBlock out3 = RDDAggregateUtils.aggStable(out2, aggop);

			//drop correction after aggregation
			out3.dropLastRowsOrColumns(aggop.correction);

			//put output block into symbol table (no lineage because single block)
			//this also includes implicit maintenance of matrix characteristics
			sec.setMatrixOutput(output.getName(), out3);
		}
		else { //MULTI_BLOCK or NONE
			if( _aggtype == SparkAggType.NONE ) {
				//in case of no block aggregation, we always drop the correction as well as
				//use a partitioning-preserving mapvalues
				out = out.mapValues(new RDDUAggValueFunction(auop, mc.getBlocksize()));
			}
			else if( _aggtype == SparkAggType.MULTI_BLOCK ) {
				//in case of multi-block aggregation, we always keep the correction
				out = out.mapToPair(new RDDUAggFunction(auop, mc.getBlocksize()));
				out = RDDAggregateUtils.aggByKeyStable(out, aggop, false);

				//drop correction after aggregation if required (aggbykey creates
				//partitioning, drop correction via partitioning-preserving mapvalues)
				if( auop.aggOp.existsCorrection() ) {
					out = out.mapValues( new AggregateDropCorrectionFunction(aggop) );
				}
			}

			//put output RDD handle into symbol table
			updateUnaryAggOutputDataCharacteristics(sec, auop.indexFn);
			sec.setRDDHandleForVariable(output.getName(), out);
			sec.addLineageRDD(output.getName(), input1.getName());
		}
	}

	/**
	 * Executes the unary aggregate over a blocked tensor RDD and binds the
	 * result to the output variable.
	 * <p>
	 * Only basic tensors are handled. Multi-block aggregation is not supported.
	 * In both supported modes, only the value at position (0,0) of each
	 * aggregate is kept, in a 1x1 tensor.
	 *
	 * @throws DMLRuntimeException if the aggregation type is {@code MULTI_BLOCK}
	 */
	private void processTensorAggregate(ExecutionContext ec) {
		SparkExecutionContext sec = (SparkExecutionContext)ec;

		//get input
		// TODO support DataTensor
		JavaPairRDD<TensorIndexes, TensorBlock> out = sec.getBinaryTensorBlockRDDHandleForVariable( input1.getName() );

		// TODO: filter input blocks for trace
		//execute unary aggregate operation
		AggregateUnaryOperator auop = (AggregateUnaryOperator)_optr;
		AggregateOperator aggop = _aop;

		//perform aggregation if necessary and put output into symbol table
		if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
			// TODO filter non empty blocks if sparse safe
			JavaRDD<TensorBlock> out2 = out.map(new RDDUTensorAggFunction2(auop));
			TensorBlock out3 = RDDAggregateUtils.aggStableTensor(out2, aggop);

			//keep only cell (0,0); this drops any correction cells
			// TODO generalize to drop depending on location of correction
			// TODO support DataTensor
			TensorBlock out4 = new TensorBlock(out3.getValueType(), new int[]{1, 1});
			out4.set(0, 0, out3.get(0, 0));

			//put output block into symbol table (no lineage because single block)
			//this also includes implicit maintenance of data characteristics
			sec.setTensorOutput(output.getName(), out4);
		}
		else { //MULTI_BLOCK or NONE
			if( _aggtype == SparkAggType.NONE ) {
				//in case of no block aggregation, we always drop the correction as well as
				//use a partitioning-preserving mapvalues
				out = out.mapValues(new RDDUTensorAggValueFunction(auop));
			}
			else if( _aggtype == SparkAggType.MULTI_BLOCK ) {
				// TODO MULTI_BLOCK: aggregate by key keeping the correction, then drop it
				throw new DMLRuntimeException("Multi block spark aggregations are not supported for tensors yet.");
			}

			//put output RDD handle into symbol table
			updateUnaryAggOutputDataCharacteristics(sec, auop.indexFn);
			sec.setRDDHandleForVariable(output.getName(), out);
			sec.addLineageRDD(output.getName(), input1.getName());
		}
	}

	/**
	 * Applies the unary aggregate to one matrix block and keeps the key.
	 * The output index is computed by
	 * {@code OperationsOnMatrixValues.performAggregateUnary}. The correction is
	 * kept so that a later by-key aggregation can use it.
	 */
	private static final class RDDUAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
	{
		private static final long serialVersionUID = 2672082409287856038L;

		private final AggregateUnaryOperator _op;
		private final int _blen;

		public RDDUAggFunction( AggregateUnaryOperator op, int blen ) {
			_op = op;
			_blen = blen;
		}

		@Override
		public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
			throws Exception
		{
			MatrixIndexes ixIn = arg0._1();
			MatrixBlock blkIn = arg0._2();

			MatrixIndexes ixOut = new MatrixIndexes();
			MatrixBlock blkOut = new MatrixBlock();

			//unary aggregate operation (always keep the correction)
			OperationsOnMatrixValues.performAggregateUnary(
				ixIn, blkIn, ixOut, blkOut, _op, _blen);

			//output new tuple
			return new Tuple2<>(ixOut, blkOut);
		}
	}

	/**
	 * Similar to RDDUAggFunction, but returns only the aggregated block (no key).
	 * This suits single-block aggregation. The correction is kept.
	 */
	public static class RDDUAggFunction2 implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock>
	{
		private static final long serialVersionUID = 2672082409287856038L;

		private final AggregateUnaryOperator _op;
		private final int _blen;

		public RDDUAggFunction2( AggregateUnaryOperator op, int blen ) {
			_op = op;
			_blen = blen;
		}

		@Override
		public MatrixBlock call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
			throws Exception
		{
			//unary aggregate operation (always keep the correction)
			return arg0._2().aggregateUnaryOperations(
				_op, new MatrixBlock(), _blen, arg0._1());
		}
	}

	/**
	 * Tensor counterpart of RDDUAggFunction2. It aggregates the basic tensor of
	 * each block and returns only the aggregated block (no key), keeping the
	 * correction.
	 */
	public static class RDDUTensorAggFunction2 implements Function<Tuple2<TensorIndexes, TensorBlock>, TensorBlock>
	{
		private static final long serialVersionUID = -6258769067791011763L;

		private final AggregateUnaryOperator _op;

		public RDDUTensorAggFunction2( AggregateUnaryOperator op ) {
			_op = op;
		}

		@Override
		public TensorBlock call(Tuple2<TensorIndexes, TensorBlock> arg0 )
			throws Exception
		{
			//unary aggregate operation (always keep the correction)
			// TODO support DataTensor
			return new TensorBlock(arg0._2().getBasicTensor().aggregateUnaryOperations(_op, new BasicTensorBlock()));
		}
	}

	/**
	 * Applies the unary aggregate to one matrix block value and drops the
	 * correction right away. No cross-block aggregation follows.
	 */
	private static final class RDDUAggValueFunction implements Function<MatrixBlock, MatrixBlock>
	{
		private static final long serialVersionUID = 5352374590399929673L;

		private final AggregateUnaryOperator _op;
		private final int _blen;
		//mapValues only receives the block value, so a fixed (1,1) index is passed
		private final MatrixIndexes _ix;

		public RDDUAggValueFunction( AggregateUnaryOperator op, int blen ) {
			_op = op;
			_blen = blen;
			_ix = new MatrixIndexes(1,1);
		}

		@Override
		public MatrixBlock call( MatrixBlock arg0 )
			throws Exception
		{
			MatrixBlock blkOut = new MatrixBlock();

			//unary aggregate operation
			arg0.aggregateUnaryOperations(_op, blkOut, _blen, _ix);

			//always drop correction since no aggregation
			blkOut.dropLastRowsOrColumns(_op.aggOp.correction);

			//output new tuple
			return blkOut;
		}
	}

	/**
	 * Tensor counterpart of RDDUAggValueFunction. It aggregates the basic tensor
	 * of one block and keeps only cell (0,0) in a 1x1 tensor, which drops the
	 * correction.
	 */
	private static final class RDDUTensorAggValueFunction implements Function<TensorBlock, TensorBlock>
	{
		private static final long serialVersionUID = -968274963539513423L;

		private final AggregateUnaryOperator _op;

		public RDDUTensorAggValueFunction(AggregateUnaryOperator op)
		{
			_op = op;
		}

		@Override
		public TensorBlock call(TensorBlock arg0 )
			throws Exception
		{
			// TODO support DataTensor
			BasicTensorBlock blkOut = new BasicTensorBlock();

			//unary aggregate operation
			arg0.getBasicTensor().aggregateUnaryOperations(_op, blkOut);

			//always drop correction since no aggregation
			// TODO generalize to drop depending on location of correction
			TensorBlock out = new TensorBlock(blkOut.getValueType(), new int[]{1, 1});
			out.set(0, 0, blkOut.get(0, 0));

			//output new tuple
			return out;
		}
	}
}