| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.cost; |
| |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.sysds.conf.ConfigurationManager; |
| import org.apache.sysds.hops.OptimizerUtils; |
| import org.apache.sysds.lops.Lop; |
| import org.apache.sysds.parser.DMLProgram; |
| import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.ForProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.IfProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.LocalVariableMap; |
| import org.apache.sysds.runtime.controlprogram.Program; |
| import org.apache.sysds.runtime.controlprogram.ProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.WhileProgramBlock; |
| import org.apache.sysds.runtime.controlprogram.caching.CacheableData.CacheStatus; |
| import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; |
| import org.apache.sysds.runtime.instructions.Instruction; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.CPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.Data; |
| import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction; |
| import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; |
| import org.apache.sysds.runtime.matrix.operators.CMOperator; |
| import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| |
| public abstract class CostEstimator |
| { |
| |
| protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName()); |
| |
| private static final int DEFAULT_NUMITER = 15; |
| |
| protected static final VarStats _unknownStats = new VarStats(1,1,-1,-1,false); |
| protected static final VarStats _scalarStats = new VarStats(1,1,1,1,true); |
| |
| public double getTimeEstimate(Program rtprog, LocalVariableMap vars, HashMap<String,VarStats> stats) { |
| double costs = 0; |
| |
| //obtain stats from symboltable (e.g., during recompile) |
| maintainVariableStatistics(vars, stats); |
| |
| //get cost estimate |
| for( ProgramBlock pb : rtprog.getProgramBlocks() ) |
| costs += rGetTimeEstimate(pb, stats, new HashSet<String>(), true); |
| |
| return costs; |
| } |
| |
| public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars, HashMap<String,VarStats> stats, boolean recursive) { |
| //obtain stats from symboltable (e.g., during recompile) |
| maintainVariableStatistics(vars, stats); |
| |
| //get cost estimate |
| return rGetTimeEstimate(pb, stats, new HashSet<String>(), recursive); |
| } |
| |
| private double rGetTimeEstimate(ProgramBlock pb, HashMap<String,VarStats> stats, HashSet<String> memoFunc, boolean recursive) { |
| double ret = 0; |
| |
| if (pb instanceof WhileProgramBlock) { |
| WhileProgramBlock tmp = (WhileProgramBlock)pb; |
| if( recursive ) |
| for (ProgramBlock pb2 : tmp.getChildBlocks()) |
| ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive); |
| ret *= DEFAULT_NUMITER; |
| } |
| else if (pb instanceof IfProgramBlock) { |
| IfProgramBlock tmp = (IfProgramBlock)pb; |
| if( recursive ) { |
| for( ProgramBlock pb2 : tmp.getChildBlocksIfBody() ) |
| ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive); |
| if( tmp.getChildBlocksElseBody()!=null ) |
| for( ProgramBlock pb2 : tmp.getChildBlocksElseBody() ){ |
| ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive); |
| ret /= 2; //weighted sum |
| } |
| } |
| } |
| else if (pb instanceof ForProgramBlock) { //includes ParFORProgramBlock |
| ForProgramBlock tmp = (ForProgramBlock)pb; |
| if( recursive ) |
| for( ProgramBlock pb2 : tmp.getChildBlocks() ) |
| ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive); |
| |
| ret *= getNumIterations(stats, tmp); |
| } |
| else if ( pb instanceof FunctionProgramBlock ) { |
| FunctionProgramBlock tmp = (FunctionProgramBlock) pb; |
| if( recursive ) |
| for( ProgramBlock pb2 : tmp.getChildBlocks() ) |
| ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive); |
| } |
| else if( pb instanceof BasicProgramBlock ) |
| { |
| BasicProgramBlock bpb = (BasicProgramBlock) pb; |
| ArrayList<Instruction> tmp = bpb.getInstructions(); |
| |
| for( Instruction inst : tmp ) |
| { |
| if( inst instanceof CPInstruction ) //CP |
| { |
| //obtain stats from createvar, cpvar, rmvar, rand |
| maintainCPInstVariableStatistics((CPInstruction)inst, stats); |
| |
| //extract statistics (instruction-specific) |
| Object[] o = extractCPInstStatistics(inst, stats); |
| VarStats[] vs = (VarStats[]) o[0]; |
| String[] attr = (String[]) o[1]; |
| |
| //if(LOG.isDebugEnabled()) |
| // LOG.debug(inst); |
| |
| //call time estimation for inst |
| ret += getCPInstTimeEstimate(inst, vs, attr); |
| |
| if( inst instanceof FunctionCallCPInstruction ) //functions |
| { |
| FunctionCallCPInstruction finst = (FunctionCallCPInstruction)inst; |
| String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName()); |
| //awareness of recursive functions, missing program |
| if( !memoFunc.contains(fkey) && pb.getProgram()!=null ) |
| { |
| if(LOG.isDebugEnabled()) |
| LOG.debug("Begin Function "+fkey); |
| |
| memoFunc.add(fkey); |
| Program prog = pb.getProgram(); |
| FunctionProgramBlock fpb = prog.getFunctionProgramBlock( |
| finst.getNamespace(), finst.getFunctionName()); |
| ret += rGetTimeEstimate(fpb, stats, memoFunc, recursive); |
| memoFunc.remove(fkey); |
| |
| if(LOG.isDebugEnabled()) |
| LOG.debug("End Function "+fkey); |
| } |
| } |
| } |
| } |
| } |
| |
| return ret; |
| } |
| |
| private static void maintainVariableStatistics( LocalVariableMap vars, HashMap<String, VarStats> stats ) { |
| for( String varname : vars.keySet() ) |
| { |
| Data dat = vars.get(varname); |
| VarStats vs = null; |
| if( dat instanceof MatrixObject ) //matrix |
| { |
| MatrixObject mo = (MatrixObject) dat; |
| DataCharacteristics dc = mo.getDataCharacteristics(); |
| long rlen = dc.getRows(); |
| long clen = dc.getCols(); |
| int blen = dc.getBlocksize(); |
| long nnz = dc.getNonZeros(); |
| boolean inmem = mo.getStatus()==CacheStatus.CACHED; |
| vs = new VarStats(rlen, clen, blen, nnz, inmem); |
| } |
| else //scalar |
| { |
| vs = _scalarStats; |
| } |
| |
| stats.put(varname, vs); |
| } |
| } |
| |
| private static void maintainCPInstVariableStatistics( CPInstruction inst, HashMap<String, VarStats> stats ) |
| { |
| if( inst instanceof VariableCPInstruction ) |
| { |
| String optype = inst.getOpcode(); |
| String[] parts = InstructionUtils.getInstructionParts(inst.toString()); |
| |
| if( optype.equals("createvar") ) { |
| if( parts.length<10 ) |
| return; |
| String varname = parts[1]; |
| long rlen = Long.parseLong(parts[6]); |
| long clen = Long.parseLong(parts[7]); |
| int blen = Integer.parseInt(parts[8]); |
| long nnz = Long.parseLong(parts[10]); |
| VarStats vs = new VarStats(rlen, clen, blen, nnz, false); |
| stats.put(varname, vs); |
| } |
| else if ( optype.equals("cpvar") ) { |
| String varname = parts[1]; |
| String varname2 = parts[2]; |
| VarStats vs = stats.get(varname); |
| stats.put(varname2, vs); |
| } |
| else if ( optype.equals("mvvar") ) { |
| String varname = parts[1]; |
| String varname2 = parts[2]; |
| VarStats vs = stats.remove(varname); |
| stats.put(varname2, vs); |
| } |
| else if( optype.equals("rmvar") ) { |
| String varname = parts[1]; |
| stats.remove(varname); |
| } |
| } |
| else if( inst instanceof DataGenCPInstruction ){ |
| DataGenCPInstruction randInst = (DataGenCPInstruction) inst; |
| String varname = randInst.output.getName(); |
| long rlen = randInst.getRows(); |
| long clen = randInst.getCols(); |
| int blen = randInst.getBlocksize(); |
| long nnz = (long) (randInst.getSparsity() * rlen * clen); |
| VarStats vs = new VarStats(rlen, clen, blen, nnz, true); |
| stats.put(varname, vs); |
| } |
| else if( inst instanceof StringInitCPInstruction ){ |
| StringInitCPInstruction iinst = (StringInitCPInstruction) inst; |
| String varname = iinst.output.getName(); |
| long rlen = iinst.getRows(); |
| long clen = iinst.getCols(); |
| VarStats vs = new VarStats(rlen, clen, ConfigurationManager.getBlocksize(), rlen*clen, true); |
| stats.put(varname, vs); |
| } |
| else if( inst instanceof FunctionCallCPInstruction ) |
| { |
| FunctionCallCPInstruction finst = (FunctionCallCPInstruction) inst; |
| for( String varname : finst.getBoundOutputParamNames() ) |
| stats.put(varname, _unknownStats); |
| } |
| } |
| |
| protected String replaceInstructionPatch( String inst ) |
| { |
| String ret = inst; |
| while( ret.contains(Lop.VARIABLE_NAME_PLACEHOLDER) ) { |
| int index1 = ret.indexOf(Lop.VARIABLE_NAME_PLACEHOLDER); |
| int index2 = ret.indexOf(Lop.VARIABLE_NAME_PLACEHOLDER, index1+1); |
| String replace = ret.substring(index1,index2+1); |
| ret = ret.replaceAll(replace, "1"); |
| } |
| |
| return ret; |
| } |
| |
| private static Object[] extractCPInstStatistics( Instruction inst, HashMap<String, VarStats> stats ) |
| { |
| Object[] ret = new Object[2]; //stats, attrs |
| VarStats[] vs = new VarStats[3]; |
| String[] attr = null; |
| |
| if( inst instanceof UnaryCPInstruction ) |
| { |
| if( inst instanceof DataGenCPInstruction ) |
| { |
| DataGenCPInstruction rinst = (DataGenCPInstruction) inst; |
| vs[0] = _unknownStats; |
| vs[1] = _unknownStats; |
| vs[2] = stats.get( rinst.output.getName() ); |
| |
| //prepare attributes for cost estimation |
| int type = 2; //full rand |
| if( rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0 ) |
| type = 0; |
| else if( rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue() ) |
| type = 1; |
| attr = new String[]{String.valueOf(type)}; |
| } |
| else if( inst instanceof StringInitCPInstruction ) |
| { |
| StringInitCPInstruction rinst = (StringInitCPInstruction) inst; |
| vs[0] = _unknownStats; |
| vs[1] = _unknownStats; |
| vs[2] = stats.get( rinst.output.getName() ); |
| } |
| else //general unary |
| { |
| UnaryCPInstruction uinst = (UnaryCPInstruction) inst; |
| vs[0] = stats.get( uinst.input1.getName() ); |
| vs[1] = _unknownStats; |
| vs[2] = stats.get( uinst.output.getName() ); |
| |
| if( vs[0] == null ) //scalar input, e.g., print |
| vs[0] = _scalarStats; |
| if( vs[2] == null ) //scalar output |
| vs[2] = _scalarStats; |
| |
| if( inst instanceof MMTSJCPInstruction ) |
| { |
| String type = ((MMTSJCPInstruction)inst).getMMTSJType().toString(); |
| attr = new String[]{type}; |
| } |
| else if( inst instanceof AggregateUnaryCPInstruction ) |
| { |
| String[] parts = InstructionUtils.getInstructionParts(inst.toString()); |
| String opcode = parts[0]; |
| if( opcode.equals("cm") ) |
| attr = new String[]{parts[parts.length-2]}; |
| } |
| } |
| } |
| else if( inst instanceof BinaryCPInstruction ) |
| { |
| BinaryCPInstruction binst = (BinaryCPInstruction) inst; |
| vs[0] = stats.get( binst.input1.getName() ); |
| vs[1] = stats.get( binst.input2.getName() ); |
| vs[2] = stats.get( binst.output.getName() ); |
| |
| |
| if( vs[0] == null ) //scalar input, |
| vs[0] = _scalarStats; |
| if( vs[1] == null ) //scalar input, |
| vs[1] = _scalarStats; |
| if( vs[2] == null ) //scalar output |
| vs[2] = _scalarStats; |
| } |
| else if( inst instanceof AggregateTernaryCPInstruction ) |
| { |
| AggregateTernaryCPInstruction binst = (AggregateTernaryCPInstruction) inst; |
| //of same dimension anyway but missing third input |
| vs[0] = stats.get( binst.input1.getName() ); |
| vs[1] = stats.get( binst.input2.getName() ); |
| vs[2] = stats.get( binst.output.getName() ); |
| |
| if( vs[0] == null ) //scalar input, |
| vs[0] = _scalarStats; |
| if( vs[1] == null ) //scalar input, |
| vs[1] = _scalarStats; |
| if( vs[2] == null ) //scalar output |
| vs[2] = _scalarStats; |
| } |
| else if( inst instanceof ParameterizedBuiltinCPInstruction ) |
| { |
| //ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst; |
| String[] parts = InstructionUtils.getInstructionParts(inst.toString()); |
| String opcode = parts[0]; |
| if( opcode.equals("groupedagg") ) |
| { |
| HashMap<String,String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts); |
| String fn = paramsMap.get("fn"); |
| String order = paramsMap.get("order"); |
| AggregateOperationTypes type = CMOperator.getAggOpType(fn, order); |
| attr = new String[]{String.valueOf(type.ordinal())}; |
| } |
| else if( opcode.equals("rmempty") ) |
| { |
| HashMap<String,String> paramsMap = ParameterizedBuiltinCPInstruction.constructParameterMap(parts); |
| attr = new String[]{String.valueOf(paramsMap.get("margin").equals("rows")?0:1)}; |
| } |
| |
| vs[0] = stats.get( parts[1].substring(7).replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER, "") ); |
| vs[1] = _unknownStats; //TODO |
| vs[2] = stats.get( parts[parts.length-1] ); |
| |
| if( vs[0] == null ) //scalar input |
| vs[0] = _scalarStats; |
| if( vs[2] == null ) //scalar output |
| vs[2] = _scalarStats; |
| } |
| else if( inst instanceof MultiReturnBuiltinCPInstruction ) |
| { |
| //applies to qr, lu, eigen (cost computation on input1) |
| MultiReturnBuiltinCPInstruction minst = (MultiReturnBuiltinCPInstruction) inst; |
| vs[0] = stats.get( minst.input1.getName() ); |
| vs[1] = stats.get( minst.getOutput(0).getName() ); |
| vs[2] = stats.get( minst.getOutput(1).getName() ); |
| } |
| else if( inst instanceof VariableCPInstruction ) |
| { |
| setUnknownStats(vs); |
| |
| VariableCPInstruction varinst = (VariableCPInstruction) inst; |
| if( varinst.getOpcode().equals("write") ) { |
| //special handling write of matrix objects (non existing if scalar) |
| if( stats.containsKey( varinst.getInput1().getName() ) ) |
| vs[0] = stats.get( varinst.getInput1().getName() ); |
| attr = new String[]{varinst.getInput3().getName()}; |
| } |
| } |
| else |
| { |
| setUnknownStats(vs); |
| } |
| |
| //maintain var status (CP output always inmem) |
| vs[2]._inmem = true; |
| |
| ret[0] = vs; |
| ret[1] = attr; |
| |
| return ret; |
| } |
| |
| private static void setUnknownStats(VarStats[] vs) { |
| vs[0] = _unknownStats; |
| vs[1] = _unknownStats; |
| vs[2] = _unknownStats; |
| } |
| |
| private static long getNumIterations(HashMap<String,VarStats> stats, ForProgramBlock pb) { |
| return OptimizerUtils.getNumIterations(pb, DEFAULT_NUMITER); |
| } |
| |
| protected abstract double getCPInstTimeEstimate( Instruction inst, VarStats[] vs, String[] args ); |
| } |