| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.instructions.spark; |
| |
| import org.apache.spark.api.java.JavaPairRDD; |
| import org.apache.spark.api.java.function.PairFlatMapFunction; |
| import org.apache.spark.broadcast.Broadcast; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.FileFormat; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; |
| import org.apache.sysds.runtime.functionobjects.SwapIndex; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator; |
| import org.apache.sysds.runtime.instructions.spark.functions.ExtractBlockForBinaryReblock; |
| import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils; |
| import org.apache.sysds.runtime.matrix.data.DnnParameters; |
| import org.apache.sysds.runtime.matrix.data.LibMatrixDNN; |
| import org.apache.sysds.runtime.matrix.data.LibMatrixDNN.PoolingType; |
| import org.apache.sysds.runtime.matrix.data.LibMatrixNative; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.data.MatrixIndexes; |
| import org.apache.sysds.runtime.matrix.operators.ReorgOperator; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import org.apache.sysds.runtime.meta.MatrixCharacteristics; |
| import org.apache.sysds.runtime.meta.MetaDataFormat; |
| import org.apache.sysds.runtime.util.DnnUtils; |
| import org.apache.sysds.utils.NativeHelper; |
| import scala.Tuple2; |
| |
| import java.util.ArrayList; |
| import java.util.Iterator; |
| |
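/**
 * Spark instruction for DNN operations (conv2d, conv2d_bias_add, maxpooling,
 * relu_maxpooling). The input matrix (one image per row in flattened N x (C*H*W)
 * layout) is reblocked so that every block holds a single row, the filter and
 * bias operands are broadcast to the executors, and the corresponding
 * LibMatrixDNN/LibMatrixNative kernel is applied per partition.
 */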
| public class DnnSPInstruction extends UnarySPInstruction { |
| private CPOperand _in2; |
| private CPOperand _in3; |
| private ArrayList<CPOperand> _input_shape; |
| private ArrayList<CPOperand> _filter_shape; |
| private ArrayList<CPOperand> _stride = new ArrayList<>(); |
| private ArrayList<CPOperand> _padding = new ArrayList<>(); |
| |
| private DnnSPInstruction(CPOperand in, CPOperand out, String opcode, String istr, |
| ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, |
| ArrayList<CPOperand> filter_shape) { |
| super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); |
| _stride = stride; |
| _padding = padding; |
| _input_shape = input_shape; |
| _filter_shape = filter_shape; |
| } |
| |
| private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, |
| ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, |
| ArrayList<CPOperand> filter_shape) { |
| super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); |
| _in2 = in2; |
| _stride = stride; |
| _padding = padding; |
| _input_shape = input_shape; |
| _filter_shape = filter_shape; |
| } |
| |
| private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, |
| String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, |
| ArrayList<CPOperand> filter_shape) { |
| super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); |
| _in2 = in2; |
| _in3 = in3; |
| _stride = stride; |
| _padding = padding; |
| _input_shape = input_shape; |
| _filter_shape = filter_shape; |
| } |
| |
| private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr) { |
| super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); |
| _in2 = in2; |
| } |
| |
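	/**
	 * Parses the given instruction string into a DnnSPInstruction. The opcode
	 * determines how many operands are expected and how the stride, padding,
	 * input_shape, and filter_shape operand lists are populated.
	 *
	 * @param str instruction string
	 * @return parsed DnnSPInstruction
	 */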
| public static DnnSPInstruction parseInstruction( String str ) { |
| CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| |
| String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
| String opcode = parts[0]; |
| if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) { |
| InstructionUtils.checkNumFields(parts, 14); |
			// parts: input, stride1, stride2, padding1, padding2,
			// input_shape1-4, filter_shape1-4, output
| in.split(parts[1]); |
| out.split(parts[14]); |
| |
| ArrayList<CPOperand> stride = new ArrayList<>(); |
| ArrayList<CPOperand> padding = new ArrayList<>(); |
| ArrayList<CPOperand> input_shape = new ArrayList<>(); |
| ArrayList<CPOperand> filter_shape = new ArrayList<>(); |
| stride.add(new CPOperand(parts[2])); |
| stride.add(new CPOperand(parts[3])); |
| padding.add(new CPOperand(parts[4])); |
| padding.add(new CPOperand(parts[5])); |
| input_shape.add(new CPOperand(parts[6])); |
| input_shape.add(new CPOperand(parts[7])); |
| input_shape.add(new CPOperand(parts[8])); |
| input_shape.add(new CPOperand(parts[9])); |
| filter_shape.add(new CPOperand(parts[10])); |
| filter_shape.add(new CPOperand(parts[11])); |
| filter_shape.add(new CPOperand(parts[12])); |
| filter_shape.add(new CPOperand(parts[13])); |
| |
| return new DnnSPInstruction(in, out, opcode, str, stride, |
| padding, input_shape, filter_shape); |
| } |
| else if (opcode.equalsIgnoreCase("maxpooling_backward") |
| || opcode.equalsIgnoreCase("conv2d") |
| || opcode.equalsIgnoreCase("conv2d_backward_filter") |
| || opcode.equalsIgnoreCase("conv2d_backward_data")) { |
| InstructionUtils.checkNumFields(parts, 15); |
			// parts: input, filter/dout (depending on opcode), stride1, stride2,
			// padding1, padding2, input_shape1-4, filter_shape1-4, output
| in.split(parts[1]); |
| CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| in2.split(parts[2]); |
| out.split(parts[15]); |
| |
| ArrayList<CPOperand> stride = new ArrayList<>(); |
| ArrayList<CPOperand> padding = new ArrayList<>(); |
| ArrayList<CPOperand> input_shape = new ArrayList<>(); |
| ArrayList<CPOperand> filter_shape = new ArrayList<>(); |
| stride.add(new CPOperand(parts[3])); |
| stride.add(new CPOperand(parts[4])); |
| padding.add(new CPOperand(parts[5])); |
| padding.add(new CPOperand(parts[6])); |
| input_shape.add(new CPOperand(parts[7])); |
| input_shape.add(new CPOperand(parts[8])); |
| input_shape.add(new CPOperand(parts[9])); |
| input_shape.add(new CPOperand(parts[10])); |
| filter_shape.add(new CPOperand(parts[11])); |
| filter_shape.add(new CPOperand(parts[12])); |
| filter_shape.add(new CPOperand(parts[13])); |
| filter_shape.add(new CPOperand(parts[14])); |
| |
| return new DnnSPInstruction(in, in2, out, opcode, str, stride, |
| padding, input_shape, filter_shape); |
| } |
| else if (opcode.equalsIgnoreCase("conv2d_bias_add")) { |
| InstructionUtils.checkNumFields(parts, 16); |
			// parts: input, bias, filter, stride1, stride2, padding1, padding2,
			// input_shape1-4, filter_shape1-4, output
| in.split(parts[1]); |
| CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| in2.split(parts[2]); |
| CPOperand in3 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| in3.split(parts[3]); |
| out.split(parts[16]); |
| |
| ArrayList<CPOperand> stride = new ArrayList<>(); |
| ArrayList<CPOperand> padding = new ArrayList<>(); |
| ArrayList<CPOperand> input_shape = new ArrayList<>(); |
| ArrayList<CPOperand> filter_shape = new ArrayList<>(); |
| stride.add(new CPOperand(parts[4])); |
| stride.add(new CPOperand(parts[5])); |
| padding.add(new CPOperand(parts[6])); |
| padding.add(new CPOperand(parts[7])); |
| input_shape.add(new CPOperand(parts[8])); |
| input_shape.add(new CPOperand(parts[9])); |
| input_shape.add(new CPOperand(parts[10])); |
| input_shape.add(new CPOperand(parts[11])); |
| filter_shape.add(new CPOperand(parts[12])); |
| filter_shape.add(new CPOperand(parts[13])); |
| filter_shape.add(new CPOperand(parts[14])); |
| filter_shape.add(new CPOperand(parts[15])); |
| |
| return new DnnSPInstruction(in, in2, in3, out, opcode, str, stride, |
| padding, input_shape, filter_shape); |
| } |
| else if (opcode.equalsIgnoreCase("bias_add")) { |
| InstructionUtils.checkNumFields(parts, 3); |
| in.split(parts[1]); |
| CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); |
| in2.split(parts[2]); |
| out.split(parts[3]); |
| return new DnnSPInstruction(in, in2, out, opcode, str); |
| } |
| else { |
			throw new DMLRuntimeException("Unknown opcode while parsing a DnnSPInstruction: " + str);
| } |
| } |
| |
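	/**
	 * Reblocks the matrix RDD bound to the given variable so that each block
	 * contains numRowsPerBlock complete rows (one image per block for
	 * numRowsPerBlock=1); the downstream map function expects such rectangular
	 * blocks with column-block index 1. If the RDD already uses blocksize 1,
	 * it is returned unchanged.
	 */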
| private static JavaPairRDD<MatrixIndexes,MatrixBlock> reblockAsRectangularMatrices(SparkExecutionContext sec, String name, int numRowsPerBlock) { |
| JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable( name ); |
| DataCharacteristics mcRdd = sec.getDataCharacteristics(name); |
| if( mcRdd.getBlocksize() != 1) { |
| DataCharacteristics mcOut = new MatrixCharacteristics(mcRdd); |
| mcOut.setBlocksize(numRowsPerBlock); |
| in1 = RDDAggregateUtils.mergeByKey(in1.flatMapToPair(new ExtractBlockForBinaryReblock(mcRdd, mcOut))); |
			// TODO: Inject a checkpoint to avoid repeating this reblock for the validation set
| // sec.setRDDHandleForVariable(name, in1); |
| // sec.setMetaData(name, new MatrixDimensionsMetaData(mcOut)); |
| } |
| return in1; |
| } |
| |
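	/**
	 * Pins the matrix bound to the given variable on the driver, releases the
	 * pinned input, and broadcasts it to the executors (used for the filter and
	 * bias operands; see the TODO about large filters at the call site).
	 */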
| private static Broadcast<MatrixBlock> getBroadcast(SparkExecutionContext sec, String name) { |
| MatrixBlock mb = sec.getMatrixInput(name); |
| sec.releaseMatrixInput(name); |
| return sec.getSparkContext().broadcast(mb); |
| } |
| |
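	/**
	 * Executes the instruction on Spark: conv2d, conv2d_bias_add, maxpooling,
	 * and relu_maxpooling are supported. The input is reblocked to one row per
	 * block, filter/bias are broadcast, the kernel is applied per partition via
	 * RDDConv2dMapMMFunction, and the output metadata (rows x K*P*Q columns for
	 * conv2d, rows x C*P*Q for pooling) is registered in the symbol table.
	 * All other DNN opcodes throw.
	 */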
| @Override |
| public void processInstruction(ExecutionContext ec) { |
| SparkExecutionContext sec = (SparkExecutionContext)ec; |
| if(instOpcode.equalsIgnoreCase("conv2d") || instOpcode.equalsIgnoreCase("conv2d_bias_add") |
| || instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { |
| String rddVar = input1.getName(); |
| int numRowsPerBlock = 1; |
| JavaPairRDD<MatrixIndexes,MatrixBlock> inputRDD = reblockAsRectangularMatrices(sec, rddVar, numRowsPerBlock); |
| DataCharacteristics mcRdd = sec.getDataCharacteristics(rddVar); |
| |
| // ------------------------------------ |
| // TODO: Handle large filters > 2G |
| Broadcast<MatrixBlock> filterBroadcast = null; |
| Broadcast<MatrixBlock> biasBroadcast = null; |
| if(instOpcode.equalsIgnoreCase("conv2d")) { |
| filterBroadcast = getBroadcast(sec, _in2.getName()); |
| } |
| else if(instOpcode.equalsIgnoreCase("conv2d_bias_add")) { |
| filterBroadcast = getBroadcast(sec, _in3.getName()); |
| biasBroadcast = getBroadcast(sec, _in2.getName()); |
| } |
| // ------------------------------------ |
| |
| int pad_h = getScalarInput(ec, _padding, 0); |
| int pad_w = getScalarInput(ec, _padding, 1); |
| int stride_h = getScalarInput(ec, _stride, 0); |
| int stride_w = getScalarInput(ec, _stride, 1); |
| |
| // int N = getScalarInput(ec, _input_shape, 0); |
| int C = getScalarInput(ec, _input_shape, 1); |
| int H = getScalarInput(ec, _input_shape, 2); |
| int W = getScalarInput(ec, _input_shape, 3); |
| |
| int K = getScalarInput(ec, _filter_shape, 0); |
| int R = getScalarInput(ec, _filter_shape, 2); |
| int S = getScalarInput(ec, _filter_shape, 3); |
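			// output spatial dimensions derived from input size, filter size, stride, and padding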
| int P = (int) DnnUtils.getP(H, R, stride_h, pad_h); |
| int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w); |
| |
| DnnParameters params = new DnnParameters(numRowsPerBlock, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, 1); |
| boolean enableNativeBLAS = NativeHelper.isNativeLibraryLoaded(); |
| JavaPairRDD<MatrixIndexes,MatrixBlock> out = inputRDD.mapPartitionsToPair(new RDDConv2dMapMMFunction(filterBroadcast, params, instOpcode, biasBroadcast, mcRdd.getRows(), enableNativeBLAS), true); |
| |
| //put output RDD handle into symbol table |
| sec.setRDDHandleForVariable(output.getName(), out); |
| sec.addLineageRDD(output.getName(), rddVar); |
| |
| long nnz = -1; // TODO: Handle nnz |
| long numCols = ((long)K)*((long)P)*Q; |
| if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { |
| numCols = ((long)C)*((long)P)*Q; |
| } |
| if(numCols > Integer.MAX_VALUE) { |
				throw new DMLRuntimeException("The current operator does not support outputs with more than Integer.MAX_VALUE columns.");
| } |
| sec.setMetaData(output.getName(), |
| new MetaDataFormat(new MatrixCharacteristics(mcRdd.getRows(), numCols, numRowsPerBlock, nnz), FileFormat.BINARY)); |
| } |
| else { |
| throw new DMLRuntimeException("Not implemented: " + instOpcode); |
| } |
| } |
| |
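	/**
	 * Reads the operand at the given index of the operand list as an integer
	 * scalar value from the execution context.
	 */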
| private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) { |
| return (int) ec.getScalarInput(aL.get(index)).getLongValue(); |
| } |
| |
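	/**
	 * Map-side function that applies the conv2d or pooling kernel to every
	 * single-row block of a partition, using the broadcast filter/bias where
	 * applicable. Native BLAS is used for conv2d if a native library is loaded.
	 */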
| private static class RDDConv2dMapMMFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> { |
| private static final long serialVersionUID = -2106155380020232155L; |
		private Broadcast<MatrixBlock> filterBroadcast = null;
		private Broadcast<MatrixBlock> biasBroadcast = null;
		private DnnParameters params = null;
		private String instOpcode = null;
		private boolean enableNative;
		private long numRows = 0;
| public RDDConv2dMapMMFunction(Broadcast<MatrixBlock> filterBroadcast, |
| DnnParameters params, String instOpcode, Broadcast<MatrixBlock> biasBroadcast, long numRows, boolean enableNativeBLAS) { |
| this.filterBroadcast = filterBroadcast; |
| this.params = params; |
| this.instOpcode = instOpcode; |
| this.biasBroadcast = biasBroadcast; |
| this.numRows = numRows; |
| this.enableNative = enableNativeBLAS; |
| } |
| |
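		/**
		 * Applies the operation identified by instOpcode to a single-row input
		 * block and returns the corresponding single-row output block; empty
		 * inputs short-circuit to an empty (sparse) output block.
		 */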
| private MatrixBlock processRectangularBlock(MatrixBlock matBlock) throws Exception { |
| MatrixBlock outputBlock = null; |
| if(instOpcode.equalsIgnoreCase("conv2d")) { |
| MatrixBlock filter = filterBroadcast.getValue(); |
| if(filter.isEmptyBlock() || matBlock.isEmptyBlock()) { |
| outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true); |
| } |
| else { |
| outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, false).allocateDenseBlock(); |
| if(enableNative) |
| LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); |
| else |
| LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); |
| } |
| } |
| else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) { |
| MatrixBlock filter = filterBroadcast.getValue(); |
| MatrixBlock bias = biasBroadcast.getValue(); |
| if((filter.isEmptyBlock() || matBlock.isEmptyBlock()) && bias.isEmptyBlock()) { |
| outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true); |
| } |
| else { |
| outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, false).allocateDenseBlock(); |
| if(!bias.isEmptyBlock()) |
| params.bias = bias; |
| if(enableNative) |
| LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); |
| else |
| LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); |
| } |
| } |
| else if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { |
| if(matBlock.isEmptyBlock()) { |
| outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true); |
| } |
| else { |
| outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock(); |
| if(instOpcode.equalsIgnoreCase("maxpooling")) |
| outputBlock.getDenseBlock().set(-Double.MAX_VALUE); |
| LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.MAX); |
| } |
| } |
| else if(instOpcode.equalsIgnoreCase("avgpooling") || instOpcode.equalsIgnoreCase("relu_avgpooling")) { |
| if(matBlock.isEmptyBlock()) { |
| outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true); |
| } |
| else { |
| outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock(); |
| LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.AVG); |
| } |
| } |
| else { |
				throw new RuntimeException("Not implemented: " + instOpcode);
| } |
| return outputBlock; |
| } |
| |
| @Override |
| public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( |
| Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) |
| throws Exception { |
| return new MapsideDnnPartitionIterator(arg0); |
| } |
| |
| // Avoid materialization of partitions |
| private class MapsideDnnPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> { |
| public MapsideDnnPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) { |
| super(in); |
| } |
| |
| @Override |
| protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception { |
| if(arg._1.getRowIndex() > numRows || arg._1.getColumnIndex() != 1) { |
					throw new RuntimeException("Expected the input to be reblocked as a rectangular RDD");
| } |
| MatrixBlock out = processRectangularBlock(arg._2); |
| if(out.getNumRows() != 1) |
| throw new RuntimeException("Expected the output to have 1 row"); |
| return new Tuple2<>(arg._1, out); |
| } |
| } |
| |
| } |
| } |