blob: 69661787ff803ae973bdbb74d602a646eb3783bb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import scala.Tuple2;
public class BinaryFrameFrameSPInstruction extends BinarySPInstruction {
protected BinaryFrameFrameSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
super(SPType.Binary, op, in1, in2, out, opcode, istr);
}
public static BinarySPInstruction parseInstruction ( String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
InstructionUtils.checkNumFields (parts, 3);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
Types.DataType dt1 = in1.getDataType();
Types.DataType dt2 = in2.getDataType();
Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
if(dt1 == Types.DataType.FRAME && dt2 == Types.DataType.FRAME)
return new BinaryFrameFrameSPInstruction(operator, in1, in2, out, opcode, str);
else
throw new DMLRuntimeException("Frame binary operation not yet implemented for frame-scalar, or frame-matrix");
}
@Override
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
// Get input RDDs
JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
JavaPairRDD<Long, FrameBlock> out = null;
if(getOpcode().equals("dropInvalidType")) {
// get schema frame-block
Broadcast<FrameBlock> fb = sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName()));
out = in1.mapValues(new isCorrectbySchema(fb.getValue()));
//release input frame
sec.releaseFrameInput(input2.getName());
}
else {
JavaPairRDD<Long, FrameBlock> in2 = sec.getFrameBinaryBlockRDDHandleForVariable(input2.getName());
// create output frame
BinaryOperator dop = (BinaryOperator) _optr;
// check for binary operations
out = in1.join(in2).mapValues(new FrameComparison(dop));
}
//set output RDD and maintain dependencies
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), input1.getName());
if( !getOpcode().equals("dropInvalidType") )
sec.addLineageRDD(output.getName(), input2.getName());
}
private static class isCorrectbySchema implements Function<FrameBlock,FrameBlock> {
private static final long serialVersionUID = 5850400295183766400L;
private FrameBlock schema_frame = null;
public isCorrectbySchema(FrameBlock fb_name ) {
schema_frame = fb_name;
}
@Override
public FrameBlock call(FrameBlock arg0) throws Exception {
return arg0.dropInvalidType(schema_frame);
}
}
private static class FrameComparison implements Function<Tuple2<FrameBlock, FrameBlock>, FrameBlock> {
private static final long serialVersionUID = 5850400295183766401L;
private final BinaryOperator bop;
public FrameComparison(BinaryOperator op){
bop = op;
}
@Override
public FrameBlock call(Tuple2<FrameBlock, FrameBlock> arg0) throws Exception {
return arg0._1().binaryOperations(bop, arg0._2(), null);
}
}
}