/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.IsFrameBlockInRange;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Iterator;
/**
 * This class implements the frame right and left indexing functionality in Spark,
 * covering right indexing (RightIndex.OPCODE) as well as broadcast-based
 * (mapLeftIndex) and general left indexing (LeftIndex.OPCODE).
 */
public class FrameIndexingSPInstruction extends IndexingSPInstruction {
protected FrameIndexingSPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl,
CPOperand cu, CPOperand out, SparkAggType aggtype, String opcode, String istr) {
super(in, rl, ru, cl, cu, out, aggtype, opcode, istr);
}
protected FrameIndexingSPInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl,
CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) {
super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr);
}
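/**
* Processes a frame indexing instruction. Right indexing either runs
* partitioning-preserving (if the row range covers all rows) or as a
* filter-and-slice over the input blocks; left indexing either looks up the
* rhs from a partitioned broadcast (mapLeftIndex) or aligns lhs and rhs via
* zero-out, slice/shift, and merge-by-key.
*/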
@Override
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
String opcode = getOpcode();
//get indexing range
long rl = ec.getScalarInput(rowLower).getLongValue();
long ru = ec.getScalarInput(rowUpper).getLongValue();
long cl = ec.getScalarInput(colLower).getLongValue();
long cu = ec.getScalarInput(colUpper).getLongValue();
IndexRange ixrange = new IndexRange(rl, ru, cl, cu);
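//note: the index range is 1-based and inclusive on both ends, consistent with
//DML indexing semantics, e.g., X[2:4,1:3] yields ixrange [2,4,1,3]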
//right indexing
if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
{
//update and check output dimensions
DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName());
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
mcOut.set(ru-rl+1, cu-cl+1, mcIn.getBlocksize(), mcIn.getBlocksize());
checkValidOutputDimensions(mcOut);
//execute right indexing operation (partitioning-preserving if possible)
JavaPairRDD<Long,FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable( input1.getName() );
JavaPairRDD<Long,FrameBlock> out = null;
if( isPartitioningPreservingRightIndexing(mcIn, ixrange) ) {
out = in1.mapPartitionsToPair(
new SliceBlockPartitionFunction(ixrange, mcOut), true);
}
else {
out = in1.filter(new IsFrameBlockInRange(rl, ru, mcOut))
.mapToPair(new SliceBlock(ixrange, mcOut));
}
//put output RDD handle into symbol table
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), input1.getName());
//update schema of output with subset of input schema
sec.getFrameObject(output.getName()).setSchema(
sec.getFrameObject(input1.getName()).getSchema((int)cl, (int)cu));
}
//left indexing
else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE) || opcode.equalsIgnoreCase("mapLeftIndex"))
{
JavaPairRDD<Long,FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable( input1.getName() );
PartitionedBroadcast<FrameBlock> broadcastIn2 = null;
JavaPairRDD<Long,FrameBlock> in2 = null;
JavaPairRDD<Long,FrameBlock> out = null;
//update and check output dimensions
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
DataCharacteristics mcLeft = ec.getDataCharacteristics(input1.getName());
mcOut.set(mcLeft.getRows(), mcLeft.getCols(), mcLeft.getBlocksize(), mcLeft.getBlocksize());
checkValidOutputDimensions(mcOut);
//note: rhs is always a frame; scalars are preprocessed via cast to 1x1 frame
DataCharacteristics mcRight = ec.getDataCharacteristics(input2.getName());
//sanity check matching index range and rhs dimensions
if(!mcRight.dimsKnown()) {
throw new DMLRuntimeException("The right input frame dimensions are not specified for FrameIndexingSPInstruction");
}
if(!(ru-rl+1 == mcRight.getRows() && cu-cl+1 == mcRight.getCols())) {
throw new DMLRuntimeException("Invalid index range of leftindexing: ["+rl+":"+ru+","+cl+":"+cu+"] vs ["+mcRight.getRows()+"x"+mcRight.getCols()+"]." );
}
if(opcode.equalsIgnoreCase("mapLeftIndex"))
{
broadcastIn2 = sec.getBroadcastForFrameVariable( input2.getName());
//partitioning-preserving mapPartitions (key access required for broadcast lookup)
out = in1.mapPartitionsToPair(
new LeftIndexPartitionFunction(broadcastIn2, ixrange, mcOut), true);
}
else { //general case
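//The general case composes three data-parallel steps: (1) zero out the target
//index range in the lhs blocks, (2) slice the rhs blocks and shift them into
//the global lhs coordinate space (re-keying them to lhs block indexes), and
//(3) union both RDDs and merge blocks with matching keys. E.g., assuming a
//single lhs block, assigning a 2x2 rhs to X[2:3,1:2] zeroes out that region
//in the lhs block, shifts the rhs block to key 1, and mergeByKey combines
//both into the final output block.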
// zero-out lhs
in1 = in1.flatMapToPair(new ZeroOutLHS(false, ixrange, mcLeft));
// slice rhs, shift and merge with lhs
in2 = sec.getFrameBinaryBlockRDDHandleForVariable( input2.getName() )
.flatMapToPair(new SliceRHSForLeftIndexing(ixrange, mcLeft));
out = FrameRDDAggregateUtils.mergeByKey(in1.union(in2));
}
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), input1.getName());
if( broadcastIn2 != null)
sec.addLineageBroadcast(output.getName(), input2.getName());
if(in2 != null)
sec.addLineageRDD(output.getName(), input2.getName());
}
else
throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in FrameIndexingSPInstruction.");
}
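/**
* Right indexing preserves the partitioning if the row range covers all rows:
* every block then keeps its key (only columns are projected), so the slicing
* can run as a mapPartitions with preservesPartitioning=true and avoid a shuffle.
*/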
private static boolean isPartitioningPreservingRightIndexing(DataCharacteristics mcIn, IndexRange ixrange) {
return ( mcIn.dimsKnown() &&
(ixrange.rowStart==1 && ixrange.rowEnd==mcIn.getRows() )); //entire column(s) selected
}
private static void checkValidOutputDimensions(DataCharacteristics mcOut) {
if(!mcOut.dimsKnown()) {
throw new DMLRuntimeException("FrameIndexingSPInstruction: The updated output dimensions are invalid: " + mcOut);
}
}
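/**
* Slices the rhs blocks and shifts them from rhs coordinates into the global
* lhs coordinate space (via OperationsOnMatrixValues.performShift), re-keying
* the resulting blocks so they can be merged with the zeroed-out lhs blocks.
*/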
private static class SliceRHSForLeftIndexing implements PairFlatMapFunction<Tuple2<Long,FrameBlock>, Long, FrameBlock>
{
private static final long serialVersionUID = 5724800998701216440L;
private IndexRange _ixrange = null;
private int _blen = -1;
private long _rlen = -1;
private long _clen = -1;
public SliceRHSForLeftIndexing(IndexRange ixrange, DataCharacteristics mcLeft) {
_ixrange = ixrange;
_rlen = mcLeft.getRows();
_clen = mcLeft.getCols();
_blen = (int) mcLeft.getCols();
}
@Override
public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> rightKV)
throws Exception
{
Pair<Long,FrameBlock> in = SparkUtils.toIndexedFrameBlock(rightKV);
ArrayList<Pair<Long,FrameBlock>> out = new ArrayList<>();
OperationsOnMatrixValues.performShift(in, _ixrange, _blen, _rlen, _clen, out);
return SparkUtils.fromIndexedFrameBlock(out).iterator();
}
}
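/**
* Zeroes out the target index range in each lhs block (or the complement of
* that range, depending on the complement flag), producing blocks that merge
* cleanly with the shifted rhs blocks.
*/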
private static class ZeroOutLHS implements PairFlatMapFunction<Tuple2<Long,FrameBlock>, Long,FrameBlock>
{
private static final long serialVersionUID = -2672267231152496854L;
private boolean _complement = false;
private IndexRange _ixrange = null;
private int _blen = -1;
private long _rlen = -1;
public ZeroOutLHS(boolean complement, IndexRange range, DataCharacteristics mcLeft) {
_complement = complement;
_ixrange = range;
_blen = (int) mcLeft.getCols();
_rlen = mcLeft.getRows();
}
@Override
public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> kv)
throws Exception
{
ArrayList<Pair<Long,FrameBlock>> out = new ArrayList<>();
IndexRange curBlockRange = new IndexRange(_ixrange.rowStart, _ixrange.rowEnd, _ixrange.colStart, _ixrange.colEnd);
// 1-based global row index of the block-aligned start of the current block
long lGblStartRow = ((kv._1.longValue()-1)/_blen)*_blen+1;
FrameBlock zeroBlk = null;
int iMaxRowsToCopy = 0;
// 0-based local row offset in the target block at which to start copying.
int iRowStartDest = UtilFunctions.computeCellInBlock(kv._1, _blen);
for(int iRowStartSrc = 0; iRowStartSrc<kv._2.getNumRows(); iRowStartSrc += iMaxRowsToCopy, lGblStartRow += _blen) {
IndexRange range = UtilFunctions.getSelectedRangeForZeroOut(
new Pair<>(kv._1, kv._2), _blen, curBlockRange, lGblStartRow-1, lGblStartRow);
if(range.rowStart == -1 && range.rowEnd == -1 && range.colStart == -1 && range.colEnd == -1) {
throw new Exception("Error while getting range for zero-out");
}
// Maximum number of rows in the target block
int iMaxRows=(int) Math.min(_blen, _rlen-lGblStartRow+1);
// Maximum number of rows to be copied from source block to target.
iMaxRowsToCopy = Math.min(iMaxRows, kv._2.getNumRows()-iRowStartSrc);
iMaxRowsToCopy = Math.min(iMaxRowsToCopy, iMaxRows-iRowStartDest);
// Zero out the applicable range in this block
zeroBlk = kv._2.zeroOutOperations(new FrameBlock(), range, _complement, iRowStartSrc, iRowStartDest, iMaxRows, iMaxRowsToCopy);
out.add(new Pair<>(lGblStartRow, zeroBlk));
curBlockRange.rowStart = lGblStartRow + _blen;
iRowStartDest = UtilFunctions.computeCellInBlock(iRowStartDest+iMaxRowsToCopy+1, _blen);
}
return SparkUtils.fromIndexedFrameBlock(out).iterator();
}
}
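/**
* Broadcast-based left indexing (mapLeftIndex): each partition of lhs blocks
* is processed in place, looking up the overlapping rhs blocks from the
* partitioned broadcast and applying leftIndexingOperations block-wise;
* blocks outside the index range are passed through unchanged.
*/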
private static class LeftIndexPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<Long,FrameBlock>>, Long, FrameBlock>
{
private static final long serialVersionUID = -911940376947364915L;
private PartitionedBroadcast<FrameBlock> _binput;
private IndexRange _ixrange = null;
public LeftIndexPartitionFunction(PartitionedBroadcast<FrameBlock> binput, IndexRange ixrange, DataCharacteristics mc)
{
_binput = binput;
_ixrange = ixrange;
}
@Override
public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> arg0)
throws Exception
{
return new LeftIndexPartitionIterator(arg0);
}
private class LeftIndexPartitionIterator extends LazyIterableIterator<Tuple2<Long, FrameBlock>>
{
public LeftIndexPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> in) {
super(in);
}
@Override
protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg)
throws Exception
{
int iNumRowsInBlock = arg._2.getNumRows();
int iNumCols = arg._2.getNumColumns();
if(!UtilFunctions.isInFrameBlockRange(arg._1(), iNumRowsInBlock, _ixrange)) {
return arg;
}
// Calculate global index of left hand side block
long lhs_rl = Math.max(_ixrange.rowStart, arg._1);
long lhs_ru = Math.min(_ixrange.rowEnd, arg._1+iNumRowsInBlock-1);
long lhs_cl = Math.max(_ixrange.colStart, 1);
long lhs_cu = Math.min(_ixrange.colEnd, iNumCols);
// Calculate global index of right hand side block
long rhs_rl = lhs_rl - _ixrange.rowStart + 1;
long rhs_ru = rhs_rl + (lhs_ru - lhs_rl);
long rhs_cl = lhs_cl - _ixrange.colStart + 1;
long rhs_cu = rhs_cl + (lhs_cu - lhs_cl);
// Provide local zero-based index to leftIndexingOperations
int lhs_lrl = (int)(lhs_rl- arg._1);
int lhs_lru = (int)(lhs_ru- arg._1);
int lhs_lcl = (int)lhs_cl-1;
int lhs_lcu = (int)lhs_cu-1;
FrameBlock ret = arg._2;
int blen = OptimizerUtils.DEFAULT_BLOCKSIZE;
long rhs_rl_pb = rhs_rl;
long rhs_ru_pb = Math.min(rhs_ru, (((rhs_rl-1)/blen)+1)*blen);
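//process the rhs one broadcast partition block at a time: rhs_ru_pb is aligned
//to the end of the partition block that contains rhs_rl, and each iteration
//slices one rhs block and left-indexes it into the current lhs block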
while(rhs_rl_pb <= rhs_ru_pb) {
// Provide global zero-based index to sliceOperations, but only for one RHS partition block at a time.
FrameBlock slicedRHSMatBlock = _binput.slice(rhs_rl_pb, rhs_ru_pb, rhs_cl, rhs_cu, new FrameBlock());
// Provide local zero-based index to leftIndexingOperations
int lhs_lrl_pb = (int) (lhs_lrl + (rhs_rl_pb - rhs_rl));
int lhs_lru_pb = (int) (lhs_lru + (rhs_ru_pb - rhs_ru));
ret = ret.leftIndexingOperations(slicedRHSMatBlock, lhs_lrl_pb, lhs_lru_pb, lhs_lcl, lhs_lcu, new FrameBlock());
rhs_rl_pb = rhs_ru_pb + 1;
rhs_ru_pb = Math.min(rhs_ru, rhs_ru_pb+blen);
}
return new Tuple2<>(arg._1, ret);
}
}
}
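/**
* Right indexing of an individual frame block: slices the requested row and
* column range out of the block and shifts the row key such that the output
* frame again starts at row index 1.
*/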
private static class SliceBlock implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
{
private static final long serialVersionUID = -5270171193018691692L;
private IndexRange _ixrange;
public SliceBlock(IndexRange ixrange, DataCharacteristics mcOut) {
_ixrange = ixrange;
}
@Override
public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> kv)
throws Exception
{
long rowindex = kv._1();
FrameBlock in = kv._2();
//prepare local index range (block guaranteed to be in range)
int rl = (int) ((rowindex > _ixrange.rowStart) ? 0 : _ixrange.rowStart-rowindex);
int ru = (int) ((_ixrange.rowEnd-rowindex >= in.getNumRows()) ?
in.getNumRows()-1 : _ixrange.rowEnd-rowindex);
//slice out the block
FrameBlock out = in.slice(rl, ru, (int)(_ixrange.colStart-1),
(int)(_ixrange.colEnd-1), new FrameBlock());
//return block with shifted row index
long rowindex2 = (rowindex > _ixrange.rowStart) ? rowindex-_ixrange.rowStart+1 : 1;
return new Tuple2<>(rowindex2, out);
}
}
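/**
* Partitioning-preserving right indexing: applied only if all rows are
* selected, in which case each block is column-sliced and keeps its key.
*/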
private static class SliceBlockPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Long, FrameBlock>
{
private static final long serialVersionUID = -1655390518299307588L;
private IndexRange _ixrange;
public SliceBlockPartitionFunction(IndexRange ixrange, DataCharacteristics mcOut) {
_ixrange = ixrange;
}
@Override
public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> arg0)
throws Exception
{
return new SliceBlockPartitionIterator(arg0);
}
/**
* NOTE: this function is only applied for slicing columns (which preserves all rows
* and hence the existing partitioning).
*/
private class SliceBlockPartitionIterator extends LazyIterableIterator<Tuple2<Long, FrameBlock>>
{
public SliceBlockPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> in) {
super(in);
}
@Override
protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg)
throws Exception
{
long rowindex = arg._1();
FrameBlock in = arg._2();
//slice out the block
FrameBlock out = in.slice(0, in.getNumRows()-1,
(int)_ixrange.colStart-1, (int)_ixrange.colEnd-1, new FrameBlock());
//return block with shifted row index
return new Tuple2<>(rowindex, out);
}
}
}
}