| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.codegen.opt; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.Iterator; |
| import java.util.LinkedHashMap; |
| import java.util.List; |
| import java.util.Map.Entry; |
| import java.util.stream.Collectors; |
| |
| import org.apache.commons.lang3.ArrayUtils; |
| import org.apache.commons.lang3.tuple.Pair; |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.sysds.api.DMLScript; |
| import org.apache.sysds.common.Types.AggOp; |
| import org.apache.sysds.common.Types.Direction; |
| import org.apache.sysds.common.Types.ExecMode; |
| import org.apache.sysds.common.Types.OpOp2; |
| import org.apache.sysds.common.Types.OpOpDG; |
| import org.apache.sysds.common.Types.OpOpData; |
| import org.apache.sysds.common.Types.OpOpN; |
| import org.apache.sysds.hops.AggBinaryOp; |
| import org.apache.sysds.hops.AggUnaryOp; |
| import org.apache.sysds.hops.BinaryOp; |
| import org.apache.sysds.hops.DnnOp; |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.hops.IndexingOp; |
| import org.apache.sysds.hops.LiteralOp; |
| import org.apache.sysds.hops.NaryOp; |
| import org.apache.sysds.hops.OptimizerUtils; |
| import org.apache.sysds.hops.ParameterizedBuiltinOp; |
| import org.apache.sysds.hops.ReorgOp; |
| import org.apache.sysds.hops.TernaryOp; |
| import org.apache.sysds.hops.UnaryOp; |
| import org.apache.sysds.hops.codegen.opt.ReachabilityGraph.SubProblem; |
| import org.apache.sysds.hops.codegen.template.CPlanMemoTable; |
| import org.apache.sysds.hops.codegen.template.TemplateOuterProduct; |
| import org.apache.sysds.hops.codegen.template.TemplateRow; |
| import org.apache.sysds.hops.codegen.template.TemplateUtils; |
| import org.apache.sysds.hops.codegen.template.CPlanMemoTable.MemoTableEntry; |
| import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType; |
| import org.apache.sysds.hops.rewrite.HopRewriteUtils; |
| import org.apache.sysds.runtime.codegen.LibSpoofPrimitives; |
| import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer; |
| import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; |
| import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; |
| import org.apache.sysds.runtime.util.CollectionUtils; |
| import org.apache.sysds.runtime.util.UtilFunctions; |
| import org.apache.sysds.utils.Statistics; |
| |
| /** |
| * This cost-based plan selection algorithm chooses fused operators |
| * based on the DAG structure and resulting overall costs. This includes |
| * holistic decisions on |
| * <ul> |
| * <li>Materialization points per consumer</li> |
| * <li>Sparsity exploitation and operator ordering</li> |
| * <li>Decisions on overlapping template types</li> |
| * <li>Decisions on multi-aggregates with shared reads</li> |
| * <li>Constraints (e.g., memory budgets and block sizes)</li> |
| * </ul> |
| * |
| */ |
| public class PlanSelectionFuseCostBasedV2 extends PlanSelection |
| { |
| private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName()); |
| |
| //common bandwidth characteristics, with a conservative write bandwidth in order |
| //to cover result allocation, write into main memory, and potential evictions |
| private static final double WRITE_BANDWIDTH_IO = 512*1024*1024; //512MB/s |
| private static final double WRITE_BANDWIDTH_MEM = 2d*1024*1024*1024; //2GB/s |
| private static final double READ_BANDWIDTH_MEM = 32d*1024*1024*1024; //32GB/s |
| private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH_IO/4; |
	private static final double COMPUTE_BANDWIDTH  =   2d*1024*1024*1024   //2GFLOPs/core
| * InfrastructureAnalyzer.getLocalParallelism(); |
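	//note (illustrative): with these constants, materializing a dense 1000x1000
	//intermediate (8MB) costs roughly 8e6/WRITE_BANDWIDTH_MEM (~4ms), whereas
	//re-reading it from main memory costs roughly 8e6/READ_BANDWIDTH_MEM (~0.25ms)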
| |
| //sparsity estimate for unknown sparsity to prefer sparse-safe fusion plans |
| private static final double SPARSE_SAFE_SPARSITY_EST = 0.1; |
| |
| //after evaluating the costs of the opening heuristics fuse-all and fuse-no-redundancy, |
| //remaining candidate plans of large partitions (w/ >= COST_MIN_EPS_NUM_POINTS) are |
| //only evaluated if the current costs are > (1+COST_MIN_EPS) * static (i.e., minimal) costs. |
| public static final double COST_MIN_EPS = 0.01; //1% |
| public static final int COST_MIN_EPS_NUM_POINTS = 20; //2^20 = 1M plans |
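	//(e.g., a partition w/ 20 interesting points spans 2^20 = 1M plans and is only
	//fully enumerated if the best opening heuristic exceeds the static minimal
	//costs by more than 1%)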
| |
| //In order to avoid unnecessary repeated reoptimization we use a plan cache for |
| //mapping partition signatures (including input sizes) to optimal plans. However, |
| //since hop ids change during dynamic recompilation, we use an approximate signature |
| //that is cheap to compute and therefore only use this for large partitions. |
| private static final int PLAN_CACHE_NUM_POINTS = 10; //2^10 = 1024 |
| private static final int PLAN_CACHE_SIZE = 1024; |
| private static final LinkedHashMap<PartitionSignature, boolean[]> _planCache = new LinkedHashMap<>(); |
| |
| //optimizer configuration |
| public static boolean COST_PRUNING = true; |
| public static boolean STRUCTURAL_PRUNING = true; |
| public static boolean PLAN_CACHING = true; |
| private static final TemplateRow ROW_TPL = new TemplateRow(); |
| |
| //cost vector id generator, whose ids are only used for memoization per call to getPlanCost; |
| //hence, we use a sequence generator per optimizer instance to avoid thread contention in |
| //multi-threaded parfor scenarios with concurrent dynamic recompilation and thus optimization. |
| private final IDSequence COST_ID = new IDSequence(); |
| |
| @Override |
| public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) |
| { |
| //step 1: analyze connected partitions (nodes, roots, mat points) |
| Collection<PlanPartition> parts = PlanAnalyzer.analyzePlanPartitions(memo, roots, true); |
| |
| //step 2: optimize individual plan partitions |
| int sumMatPoints = 0; |
| for( PlanPartition part : parts ) { |
| //create composite templates (within the partition) |
| createAndAddMultiAggPlans(memo, part.getPartition(), part.getRoots()); |
| |
| //plan enumeration and plan selection |
| selectPlans(memo, part); |
| sumMatPoints += part.getMatPointsExt().length; |
| } |
| |
| //step 3: add composite templates (across partitions) |
| createAndAddMultiAggPlans(memo, roots); |
| |
| //take all distinct best plans |
| for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) |
| memo.setDistinct(e.getKey(), e.getValue()); |
| |
| //maintain statistics |
| if( DMLScript.STATISTICS ) { |
| if( sumMatPoints >= 63 ) |
| LOG.warn("Long overflow on maintaining codegen statistics " |
| + "for a DAG with "+sumMatPoints+" interesting points."); |
| Statistics.incrementCodegenEnumAll(UtilFunctions.pow(2, sumMatPoints)); |
| } |
| } |
| |
| private void selectPlans(CPlanMemoTable memo, PlanPartition part) |
| { |
| //prune special case patterns and invalid plans (e.g., blocksize) |
| pruneInvalidAndSpecialCasePlans(memo, part); |
| |
| //if no materialization points, use basic fuse-all w/ partition awareness |
| if( part.getMatPointsExt() == null || part.getMatPointsExt().length==0 ) { |
| for( Long hopID : part.getRoots() ) |
| rSelectPlansFuseAll(memo, |
| memo.getHopRefs().get(hopID), null, part.getPartition()); |
| } |
| else { |
| //obtain hop compute costs per cell once |
| HashMap<Long, Double> computeCosts = new HashMap<>(); |
| for( Long hopID : part.getPartition() ) |
| getComputeCosts(memo.getHopRefs().get(hopID), computeCosts); |
| |
| //prepare pruning helpers and prune memo table w/ determined mat points |
| StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts), |
| getReadCost(part, memo), getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo)); |
| ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null; |
| if( STRUCTURAL_PRUNING ) { |
| part.setMatPointsExt(rgraph.getSortedSearchSpace()); |
| for( Long hopID : part.getPartition() ) |
| memo.pruneRedundant(hopID, true, part.getMatPointsExt()); |
| } |
| |
			//enumerate and cost plans, returns optimal plan
| boolean[] bestPlan = enumPlans(memo, part, |
| costs, rgraph, part.getMatPointsExt(), 0); |
| |
| //prune memo table wrt best plan and select plans |
| HashSet<Long> visited = new HashSet<>(); |
| for( Long hopID : part.getRoots() ) |
| rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID), |
| visited, part, part.getMatPointsExt(), bestPlan); |
| HashSet<Long> visited2 = new HashSet<>(); |
| for( Long hopID : part.getRoots() ) |
| rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID), |
| visited2, part, bestPlan); |
| |
| for( Long hopID : part.getRoots() ) |
| rSelectPlansFuseAll(memo, |
| memo.getHopRefs().get(hopID), null, part.getPartition()); |
| } |
| } |
| |
| /** |
| * Core plan enumeration algorithm, invoked recursively for conditionally independent |
| * subproblems. This algorithm fully explores the exponential search space of 2^m, |
| * where m is the number of interesting materialization points. We iterate over |
	 * a linearized search space without ever instantiating the search tree. Furthermore,
	 * in order to reduce the enumeration overhead, we apply two high-impact pruning
	 * techniques: (1) pruning by evolving lower/upper cost bounds, and (2) pruning by
| * conditional structural properties (so-called cutsets of interesting points). |
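	 * <p>
	 * For example, m=3 interesting points span a linearized space of 2^3 = 8
	 * assignments, where index 0 materializes no points (fuse-all) and index 7
	 * materializes all points (fuse-no-redundancy).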
| * |
| * @param memo memoization table of partial fusion plans |
	 * @param part connected component (partition) of partial fusion plans with all necessary metadata
| * @param costs summary of static costs (e.g., partition reads, writes, and compute costs per operator) |
| * @param rgraph reachability graph of interesting materialization points |
	 * @param matPoints sorted materialization points (defining the search space)
| * @param off offset for recursive invocation, indicating the fixed plan part |
| * @return optimal assignment of materialization points |
| */ |
| private boolean[] enumPlans(CPlanMemoTable memo, PlanPartition part, StaticCosts costs, |
| ReachabilityGraph rgraph, InterestingPoint[] matPoints, int off) |
| { |
| //scan linearized search space, w/ skips for branch and bound pruning |
| //and structural pruning (where we solve conditionally independent problems) |
| //bestC is monotonically non-increasing and serves as the upper bound |
| final int Mlen = matPoints.length-off; |
| final long len = UtilFunctions.pow(2, Mlen); |
| long numEvalPlans = 2, numEvalPartPlans = 0; |
| |
		//evaluate heuristics fuse-all and fuse-no-redundancy to quickly obtain a good upper bound
| final boolean[] plan0 = createAssignment(Mlen, off, 0); // fuse-all |
| final boolean[] planN = createAssignment(Mlen, off, len-1); //fuse-no-redundancy |
| final double C0 = getPlanCost(memo, part, matPoints, plan0, costs._computeCosts, Double.MAX_VALUE); |
| final double CN = getPlanCost(memo, part, matPoints, planN, costs._computeCosts, Double.MAX_VALUE); |
| boolean[] bestPlan = (C0 <= CN) ? plan0 : planN; |
| double bestC = Math.min(C0, CN); |
| final boolean evalRemain = (Mlen < COST_MIN_EPS_NUM_POINTS |
| || !COST_PRUNING || bestC > (1+COST_MIN_EPS) * costs.getMinCosts()); |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Enum opening: " + Arrays.toString(bestPlan) + " -> " + bestC); |
| if( !evalRemain ) |
| LOG.warn("Skip enum for |M|="+Mlen+", C="+bestC+", Cmin="+costs.getMinCosts()); |
| |
| //probe plan cache for existing optimized plan |
| PartitionSignature pKey = null; |
| if( probePlanCache(matPoints) ) { |
| pKey = new PartitionSignature(part, matPoints.length, costs, C0, CN); |
| boolean[] plan = getPlan(pKey); |
| if( plan != null ) { |
| Statistics.incrementCodegenEnumAllP((rgraph!=null||!STRUCTURAL_PRUNING)?len:0); |
| return plan; |
| } |
| } |
| |
| //evaluate remaining plans, except already evaluated heuristics |
		for( long i=1; i<len-1 && evalRemain; i++ ) {
| //construct assignment |
| boolean[] plan = createAssignment(Mlen, off, i); |
| long pskip = 0; //skip after costing |
| |
| //skip plans with structural pruning |
| if( STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) { |
| //compute skip (which also acts as boundary for subproblems) |
| pskip = rgraph.getNumSkipPlans(plan); |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Enum: Structural pruning for cut set: "+rgraph.getCutSet(plan)); |
| |
				//get the conditionally independent subproblems from the reachability graph
| SubProblem[] prob = rgraph.getSubproblems(plan); |
| |
| //solve subproblems independently and combine into best plan |
| for( int j=0; j<prob.length; j++ ) { |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Enum: Subproblem "+(j+1)+"/"+prob.length+": "+prob[j]); |
| boolean[] bestTmp = enumPlans(memo, part, |
| costs, null, prob[j].freeMat, prob[j].offset); |
| LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos); |
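					//(vectWrite scatters the sub-solution into the overall plan
					//at the subproblem's free positions)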
| } |
| |
				//note: the overall plan costs are evaluated in full, which reuses
				//the default code path; hence, we postpone the skip until after costing
| } |
| //skip plans with branch and bound pruning (cost) |
| else if( COST_PRUNING ) { |
| double lbC = getLowerBoundCosts(part, matPoints, memo, costs, plan); |
| if( lbC >= bestC ) { |
| long skip = getNumSkipPlans(plan); |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Enum: Skip "+skip+" plans (by cost)."); |
| i += skip - 1; |
| continue; |
| } |
| } |
| |
			//cost assignment on hops; stop early if costs exceed bestC
| double pCBound = COST_PRUNING ? bestC : Double.MAX_VALUE; |
| double C = getPlanCost(memo, part, matPoints, plan, costs._computeCosts, pCBound); |
| if (LOG.isTraceEnabled()) |
| LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + C); |
| numEvalPartPlans += (C==Double.POSITIVE_INFINITY) ? 1 : 0; |
| numEvalPlans++; |
| |
| //cost comparisons |
| if( bestPlan == null || C < bestC ) { |
| bestC = C; |
| bestPlan = plan; |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Enum: Found new best plan."); |
| } |
| |
| //post skipping |
| i += pskip; |
| if( pskip !=0 && LOG.isTraceEnabled() ) |
| LOG.trace("Enum: Skip "+pskip+" plans (by structure)."); |
| } |
| |
| if( DMLScript.STATISTICS ) { |
| Statistics.incrementCodegenEnumAllP((rgraph!=null||!STRUCTURAL_PRUNING)?len:0); |
| Statistics.incrementCodegenEnumEval(numEvalPlans); |
| Statistics.incrementCodegenEnumEvalP(numEvalPartPlans); |
| } |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Enum: Optimal plan: "+Arrays.toString(bestPlan)); |
| |
		//cache best plans of large partitions
| if( probePlanCache(matPoints) ) |
| putPlan(pKey, bestPlan); |
| |
		//copy best plan w/o the fixed offset prefix
| return (bestPlan==null) ? new boolean[Mlen] : |
| Arrays.copyOfRange(bestPlan, off, bestPlan.length); |
| } |
| |
| private static boolean[] createAssignment(int len, int off, long pos) { |
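		//illustrative trace: len=3, off=1, pos=5 (binary 101) yields
		//[true, true, false, true], i.e., a fixed all-true prefix of length off
		//followed by the MSB-first decoded bits of pos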
| boolean[] ret = new boolean[off+len]; |
| Arrays.fill(ret, 0, off, true); |
| long tmp = pos; |
| for( int i=0; i<len; i++ ) { |
| long mask = UtilFunctions.pow(2, len-i-1); |
| ret[off+i] = tmp >= mask; |
| tmp %= mask; |
| } |
| return ret; |
| } |
| |
| private static long getNumSkipPlans(boolean[] plan) { |
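		//illustrative trace: plan [F,T,F,F] gives pos=1 and a skip of 2^(4-1-1)=4
		//plans, i.e., all assignments that merely add materialization points after
		//the last selected one (their lower bounds can only be larger)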
| int pos = ArrayUtils.lastIndexOf(plan, true); |
| return UtilFunctions.pow(2, plan.length-pos-1); |
| } |
| |
| private static double getLowerBoundCosts(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, StaticCosts costs, boolean[] plan) { |
| //compute the lower bound from static and plan-dependent costs |
| double lb = Math.max(costs._read, costs._compute) + costs._write |
| + getMaterializationCost(part, M, memo, plan); |
| |
| //if the partition contains outer templates, we need to correct the lower bound |
| if( part.hasOuter() ) |
| lb *= costs._minSparsity; |
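		//(outer-product plans scale their costs by the sparsity of their largest
		//input, so multiplying by the minimum sparsity over the partition
		//preserves a valid lower bound)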
| |
| return lb; |
| } |
| |
| private static double getMaterializationCost(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) { |
| double costs = 0; |
| //currently active materialization points |
| HashSet<Long> matTargets = new HashSet<>(); |
| for( int i=0; i<plan.length; i++ ) { |
| long hopID = M[i].getToHopID(); |
| if( plan[i] && !matTargets.contains(hopID) ) { |
| matTargets.add(hopID); |
| Hop hop = memo.getHopRefs().get(hopID); |
| long size = getSize(hop); |
| costs += size * 8 / WRITE_BANDWIDTH_MEM + |
| size * 8 / READ_BANDWIDTH_MEM; |
| } |
| } |
| //points with non-partition consumers |
| for( Long hopID : part.getExtConsumed() ) |
| if( !matTargets.contains(hopID) ) { |
| matTargets.add(hopID); |
| Hop hop = memo.getHopRefs().get(hopID); |
| costs += getSize(hop) * 8 / WRITE_BANDWIDTH_MEM; |
| } |
| |
| return costs; |
| } |
| |
| private static double getReadCost(PlanPartition part, CPlanMemoTable memo) { |
| double costs = 0; |
| //get partition input reads (at least read once) |
| for( Long hopID : part.getInputs() ) { |
| Hop hop = memo.getHopRefs().get(hopID); |
| costs += getSafeMemEst(hop) / READ_BANDWIDTH_MEM; |
| } |
| return costs; |
| } |
| |
| private static double getWriteCost(Collection<Long> R, CPlanMemoTable memo) { |
| double costs = 0; |
| for( Long hopID : R ) { |
| Hop hop = memo.getHopRefs().get(hopID); |
| costs += getSize(hop) * 8 / WRITE_BANDWIDTH_MEM; |
| } |
| return costs; |
| } |
| |
| private static double sumComputeCost(HashMap<Long, Double> computeCosts) { |
| return computeCosts.values().stream() |
| .mapToDouble(d -> d/COMPUTE_BANDWIDTH).sum(); |
| } |
| |
| private static double minOuterSparsity(PlanPartition part, CPlanMemoTable memo) { |
| return !part.hasOuter() ? 1.0 : part.getPartition().stream() |
| .map(k -> HopRewriteUtils.getLargestInput(memo.getHopRefs().get(k))) |
| .mapToDouble(h -> h.dimsKnown(true) ? h.getSparsity() : SPARSE_SAFE_SPARSITY_EST) |
| .min().orElse(SPARSE_SAFE_SPARSITY_EST); |
| } |
| |
| private static double sumTmpInputOutputSize(CPlanMemoTable memo, CostVector vect) { |
		//size of intermediate inputs and outputs, i.e., output and inputs other than transient reads
| return vect.outSize + vect.inSizes.entrySet().stream() |
| .filter(e -> !HopRewriteUtils.isData(memo.getHopRefs().get(e.getKey()), OpOpData.TRANSIENTREAD)) |
| .mapToDouble(e -> e.getValue()).sum(); |
| } |
| |
| private static double sumInputMemoryEstimates(CPlanMemoTable memo, CostVector vect) { |
| return vect.inSizes.keySet().stream() |
| .mapToDouble(e -> getSafeMemEst(memo.getHopRefs().get(e))).sum(); |
| } |
| |
| private static double getSafeMemEst(Hop hop) { |
| return !hop.dimsKnown() ? getSize(hop) * 8 |
| : hop.getOutputMemEstimate(); |
| } |
| |
| private static long getSize(Hop hop) { |
| return Math.max(hop.getDim1(),1) |
| * Math.max(hop.getDim2(),1); |
| } |
| |
| //within-partition multi-agg templates |
| private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) |
| { |
| //create index of plans that reference full aggregates to avoid circular dependencies |
| HashSet<Long> refHops = new HashSet<>(); |
| for( Entry<Long, List<MemoTableEntry>> e : memo.getPlans().entrySet() ) |
| if( !e.getValue().isEmpty() ) { |
| Hop hop = memo.getHopRefs().get(e.getKey()); |
| for( Hop c : hop.getInput() ) |
| refHops.add(c.getHopID()); |
| } |
| |
		//find all full aggregations (being in the same partition guarantees that they
		//have common subexpressions; moreover, full aggregations are by definition root nodes)
| ArrayList<Long> fullAggs = new ArrayList<>(); |
| for( Long hopID : R ) { |
| Hop root = memo.getHopRefs().get(hopID); |
| if( !refHops.contains(hopID) && isMultiAggregateRoot(root) ) |
| fullAggs.add(hopID); |
| } |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("Found within-partition ua(RC) aggregations: " + |
| Arrays.toString(fullAggs.toArray(new Long[0]))); |
| } |
| |
| //construct and add multiagg template plans (w/ max 3 aggregations) |
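		//(e.g., five candidate aggregates A,B,C,D,E yield the groups (A,B,C) and (D,E))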
| for( int i=0; i<fullAggs.size(); i+=3 ) { |
| int ito = Math.min(i+3, fullAggs.size()); |
| if( ito-i >= 2 ) { |
| MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, |
| fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1, ito-i); |
| if( isValidMultiAggregate(memo, me) ) { |
| for( int j=i; j<ito; j++ ) { |
| memo.add(memo.getHopRefs().get(fullAggs.get(j)), me); |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Added multiagg plan: "+fullAggs.get(j)+" "+me); |
| } |
| } |
| else if( LOG.isTraceEnabled() ) { |
| LOG.trace("Removed invalid multiagg plan: "+me); |
| } |
| } |
| } |
| } |
| |
| //across-partition multi-agg templates with shared reads |
| private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) |
| { |
| //collect full aggregations as initial set of candidates |
| HashSet<Long> fullAggs = new HashSet<>(); |
| Hop.resetVisitStatus(roots); |
| for( Hop hop : roots ) |
| rCollectFullAggregates(hop, fullAggs); |
| Hop.resetVisitStatus(roots); |
| |
| //remove operators with assigned multi-agg plans |
| fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG)); |
| |
| //check applicability for further analysis |
| if( fullAggs.size() <= 1 ) |
| return; |
| |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("Found across-partition ua(RC) aggregations: " + |
| Arrays.toString(fullAggs.toArray(new Long[0]))); |
| } |
| |
| //collect information for all candidates |
| //(subsumed aggregations, and inputs to fused operators) |
| List<AggregateInfo> aggInfos = new ArrayList<>(); |
| for( Long hopID : fullAggs ) { |
| Hop aggHop = memo.getHopRefs().get(hopID); |
| AggregateInfo tmp = new AggregateInfo(aggHop); |
| for( int i=0; i<aggHop.getInput().size(); i++ ) { |
| Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? |
| aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i); |
| rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL); |
| } |
| if( tmp._fusedInputs.isEmpty() ) { |
| if( HopRewriteUtils.isMatrixMultiply(aggHop) ) { |
| tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID()); |
| tmp.addFusedInput(aggHop.getInput().get(1).getHopID()); |
| } |
| else |
| tmp.addFusedInput(aggHop.getInput().get(0).getHopID()); |
| } |
| aggInfos.add(tmp); |
| } |
| |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("Extracted across-partition ua(RC) aggregation info: "); |
| for( AggregateInfo info : aggInfos ) |
| LOG.trace(info); |
| } |
| |
| //sort aggregations by num dependencies to simplify merging |
| //clusters of aggregations with parallel dependencies |
| aggInfos = aggInfos.stream() |
| .sorted(Comparator.comparing(a -> a._inputAggs.size())) |
| .collect(Collectors.toList()); |
| |
| //greedy grouping of multi-agg candidates |
| boolean converged = false; |
| while( !converged ) { |
| AggregateInfo merged = null; |
| for( int i=0; i<aggInfos.size(); i++ ) { |
| AggregateInfo current = aggInfos.get(i); |
| for( int j=i+1; j<aggInfos.size(); j++ ) { |
| AggregateInfo that = aggInfos.get(j); |
| if( current.isMergable(that) ) { |
| merged = current.merge(that); |
| aggInfos.remove(j); j--; |
| } |
| } |
| } |
| converged = (merged == null); |
| } |
| |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("Merged across-partition ua(RC) aggregation info: "); |
| for( AggregateInfo info : aggInfos ) |
| LOG.trace(info); |
| } |
| |
| //construct and add multiagg template plans (w/ max 3 aggregations) |
| for( AggregateInfo info : aggInfos ) { |
| if( info._aggregates.size()<=1 ) |
| continue; |
| Long[] aggs = info._aggregates.keySet().toArray(new Long[0]); |
| MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, |
| aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length); |
| for( int i=0; i<aggs.length; i++ ) { |
| memo.add(memo.getHopRefs().get(aggs[i]), me); |
| addBestPlan(aggs[i], me); |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Added multiagg* plan: "+aggs[i]+" "+me); |
| |
| } |
| } |
| } |
| |
| private static boolean isMultiAggregateRoot(Hop root) { |
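		//e.g., full aggregates sum(X), sum(X^2), min(X), max(X), or scalar
		//dot products t(x)%*%y w/ 1x1 output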
| return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) |
| && ((AggUnaryOp)root).getDirection()==Direction.RowCol) |
| || (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1 |
| && HopRewriteUtils.isTransposeOperation(root.getInput().get(0))); |
| } |
| |
| private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) { |
		//ensure consistent input sizes (otherwise potential for incorrect results)
| boolean ret = true; |
| Hop refSize = memo.getHopRefs().get(me.input1).getInput().get(0); |
| for( int i=1; ret && i<3; i++ ) { |
| if( me.isPlanRef(i) ) |
| ret &= HopRewriteUtils.isEqualSize(refSize, |
| memo.getHopRefs().get(me.input(i)).getInput().get(0)); |
| } |
| |
		//ensure that aggregates are independent of each other, i.e.,
		//they do not have potentially transitive parent-child references
| for( int i=0; ret && i<3; i++ ) |
| if( me.isPlanRef(i) ) { |
| HashSet<Long> probe = new HashSet<>(); |
| for( int j=0; j<3; j++ ) |
| if( i != j ) |
| probe.add(me.input(j)); |
| ret &= rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), probe); |
| } |
| return ret; |
| } |
| |
| private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) { |
| boolean ret = true; |
| for( Hop c : current.getInput() ) |
| ret &= rCheckMultiAggregate(c, probe); |
| ret &= !probe.contains(current.getHopID()); |
| return ret; |
| } |
| |
| private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) { |
| if( current.isVisited() ) |
| return; |
| |
| //collect all applicable full aggregations per read |
| if( isMultiAggregateRoot(current) ) |
| aggs.add(current.getHopID()); |
| |
| //recursively process children |
| for( Hop c : current.getInput() ) |
| rCollectFullAggregates(c, aggs); |
| |
| current.setVisited(); |
| } |
| |
| private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) { |
| //collect input aggregates (dependents) |
| if( isMultiAggregateRoot(current) ) |
| aggInfo.addInputAggregate(current.getHopID()); |
| |
| //recursively process children |
| MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null; |
| for( int i=0; i<current.getInput().size(); i++ ) { |
| Hop c = current.getInput().get(i); |
| if( me != null && me.isPlanRef(i) ) |
| rExtractAggregateInfo(memo, c, aggInfo, type); |
| else { |
| if( type != null && c.getDataType().isMatrix() ) //add fused input |
| aggInfo.addFusedInput(c.getHopID()); |
| rExtractAggregateInfo(memo, c, aggInfo, null); |
| } |
| } |
| } |
| |
| private static HashSet<Long> collectIrreplaceableRowOps(CPlanMemoTable memo, PlanPartition part) { |
		//get row entries that are (a) reachable from row-wise ops (top-down), other
		//than operator root nodes, or (b) dependent upon row-wise ops (bottom-up)
| HashSet<Long> excludeList = new HashSet<>(); |
| HashSet<Pair<Long, Integer>> visited = new HashSet<>(); |
| for( Long hopID : part.getRoots() ) { |
| rCollectDependentRowOps(memo.getHopRefs().get(hopID), |
| memo, part, excludeList, visited, null, false); |
| } |
| return excludeList; |
| } |
| |
| private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part, |
| HashSet<Long> excludeList, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp) |
| { |
| //avoid redundant evaluation of processed and non-partition nodes |
| Pair<Long, Integer> key = Pair.of(hop.getHopID(), |
| (foundRowOp?Short.MAX_VALUE:0) + ((type!=null)?type.ordinal()+1:0)); |
| if( visited.contains(key) || !part.getPartition().contains(hop.getHopID()) ) { |
| return; |
| } |
| |
| //process node itself (top-down) |
| MemoTableEntry me = (type == null) ? memo.getBest(hop.getHopID()) : |
| memo.getBest(hop.getHopID(), type); |
| boolean inRow = (me != null && me.type == TemplateType.ROW && type == TemplateType.ROW); |
| boolean diffPlans = part.getMatPointsExt().length > 0 //guard against plan differences |
| && memo.contains(hop.getHopID(), TemplateType.ROW) |
| && !memo.hasOnlyExactMatches(hop.getHopID(), TemplateType.ROW, TemplateType.CELL); |
| if( inRow && foundRowOp ) |
| excludeList.add(hop.getHopID()); |
| if( isRowAggOp(hop, inRow) || diffPlans ) { |
| excludeList.add(hop.getHopID()); |
| foundRowOp = true; |
| } |
| |
| //process children recursively |
| for( int i=0; i<hop.getInput().size(); i++ ) { |
| boolean lfoundRowOp = foundRowOp && me != null |
| && (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type)); |
| rCollectDependentRowOps(hop.getInput().get(i), memo, |
| part, excludeList, visited, me!=null?me.type:null, lfoundRowOp); |
| } |
| |
| //process node itself (bottom-up) |
| if( !excludeList.contains(hop.getHopID()) ) { |
| for( int i=0; i<hop.getInput().size(); i++ ) |
| if( me != null && me.type == TemplateType.ROW |
| && (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type)) |
| && excludeList.contains(hop.getInput().get(i).getHopID()) ) { |
| excludeList.add(hop.getHopID()); |
| } |
| } |
| |
| visited.add(key); |
| } |
| |
| private static boolean isRowAggOp(Hop hop, boolean inRow) { |
| return HopRewriteUtils.isBinary(hop, OpOp2.CBIND) |
| || HopRewriteUtils.isNary(hop, OpOpN.CBIND) |
| || (hop instanceof AggBinaryOp && (inRow || !hop.dimsKnown() |
| || (hop.getDim1()!=1 && hop.getDim2()!=1))) |
| || (HopRewriteUtils.isTransposeOperation(hop) |
| && (hop.getDim1()!=1 && hop.getDim2()!=1) |
| && !HopRewriteUtils.isDataGenOp(hop.getInput().get(0),OpOpDG.SEQ)) |
| || (hop instanceof AggUnaryOp && inRow); |
| } |
| |
| private static boolean isValidRow2CellOp(Hop hop) { |
| return !(HopRewriteUtils.isBinary(hop, OpOp2.CBIND) |
| || (hop instanceof AggBinaryOp && hop.getDim1()!=1 && hop.getDim2()!=1)); |
| } |
| |
| private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) |
| { |
| //prune invalid row entries w/ violated blocksize constraint |
| if( OptimizerUtils.isSparkExecutionMode() ) { |
| for( Long hopID : part.getPartition() ) { |
| if( !memo.contains(hopID, TemplateType.ROW) ) |
| continue; |
| Hop hop = memo.getHopRefs().get(hopID); |
| boolean isSpark = DMLScript.getGlobalExecMode() == ExecMode.SPARK |
| || OptimizerUtils.getTotalMemEstimate(hop.getInput().toArray(new Hop[0]), hop, true) |
| > OptimizerUtils.getLocalMemBudget(); |
| boolean validNcol = hop.getDataType().isScalar() || (HopRewriteUtils.isTransposeOperation(hop) ? |
| hop.getDim1() <= hop.getBlocksize() : hop.getDim2() <= hop.getBlocksize()); |
| for( Hop in : hop.getInput() ) |
| validNcol &= in.getDataType().isScalar() |
| || (in.getDim2() <= in.getBlocksize()) |
| || (hop instanceof AggBinaryOp && in.getDim1() <= in.getBlocksize() |
| && HopRewriteUtils.isTransposeOperation(in)); |
| if( isSpark && !validNcol ) { |
| List<MemoTableEntry> excludeList = memo.get(hopID, TemplateType.ROW); |
| memo.remove(memo.getHopRefs().get(hopID), TemplateType.ROW); |
| memo.removeAllRefTo(hopID, TemplateType.ROW); |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("Removed row memo table entries w/ violated blocksize constraint ("+hopID+"): " |
| + Arrays.toString(excludeList.toArray(new MemoTableEntry[0]))); |
| } |
| } |
| } |
| } |
| |
| //prune row aggregates with pure cellwise operations |
| //(we determine an excludeList of all operators in a partition that either |
| //depend upon row aggregates or on which row aggregates depend) |
| HashSet<Long> excludeList = collectIrreplaceableRowOps(memo, part); |
| for( Long hopID : part.getPartition() ) { |
| if( excludeList.contains(hopID) ) continue; |
| MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); |
| if( me != null && me.type == TemplateType.ROW |
| && memo.hasOnlyExactMatches(hopID, TemplateType.ROW, TemplateType.CELL) ) { |
| List<MemoTableEntry> rmList = memo.get(hopID, TemplateType.ROW); |
| memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(rmList)); |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("Removed row memo table entries w/o aggregation: " |
| + Arrays.toString(rmList.toArray(new MemoTableEntry[0]))); |
| } |
| } |
| } |
| |
| //prune suboptimal outer product plans that are dominated by outer product plans w/ same number of |
| //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), |
		//we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern).
| for( Long hopID : part.getPartition() ) { |
| if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) { |
| List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER); |
| MemoTableEntry me1 = entries.get(0); |
| MemoTableEntry me2 = entries.get(1); |
| MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2); |
| if( rmEntry != null ) { |
| memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry)); |
| memo.getPlansExcludeListed().remove(rmEntry.input(rmEntry.getPlanRefIndex())); |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Removed dominated outer product memo table entry: " + rmEntry); |
| } |
| } |
| } |
| } |
| |
| private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, |
| PlanPartition part, InterestingPoint[] matPoints, boolean[] plan) |
| { |
| //memoization (not via hops because in middle of dag) |
| if( visited.contains(current.getHopID()) ) |
| return; |
| |
| //remove memo table entries if necessary |
| long hopID = current.getHopID(); |
| if( part.getPartition().contains(hopID) && memo.contains(hopID) ) { |
| Iterator<MemoTableEntry> iter = memo.get(hopID).iterator(); |
| while( iter.hasNext() ) { |
| MemoTableEntry me = iter.next(); |
| if( !hasNoRefToMatPoint(hopID, me, matPoints, plan) && me.type!=TemplateType.OUTER ) { |
| iter.remove(); |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Removed memo table entry: "+me); |
| } |
| } |
| } |
| |
| //process children recursively |
| for( Hop c : current.getInput() ) |
| rPruneSuboptimalPlans(memo, c, visited, part, matPoints, plan); |
| |
| visited.add(current.getHopID()); |
| } |
| |
| private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, boolean[] plan) { |
| //memoization (not via hops because in middle of dag) |
| if( visited.contains(current.getHopID()) ) |
| return; |
| |
| //process children recursively |
| for( Hop c : current.getInput() ) |
| rPruneInvalidPlans(memo, c, visited, part, plan); |
| |
| //find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs, |
| //i.e., plans that become invalid after the previous pruning step |
| long hopID = current.getHopID(); |
| if( part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) { |
| Iterator<MemoTableEntry> iter = memo.get(hopID, TemplateType.ROW).iterator(); |
| while( iter.hasNext() ) { |
| MemoTableEntry me = iter.next(); |
| //convert leaf node with pure vector inputs |
| boolean applyLeaf = (!me.hasPlanRef() |
| && !TemplateUtils.hasMatrixInput(current)); |
| |
| //convert inner node without row template input |
| boolean applyInner = !applyLeaf && !ROW_TPL.open(current); |
				for( int i=0; i<3 && applyInner; i++ )
| if( me.isPlanRef(i) ) |
| applyInner &= !memo.contains(me.input(i), TemplateType.ROW); |
| |
| if( applyLeaf || applyInner ) { |
| String type = applyLeaf ? "leaf" : "inner"; |
| if( isValidRow2CellOp(current) ) { |
| me.type = TemplateType.CELL; |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Converted "+type+" memo table entry from row to cell: "+me); |
| } |
| else { |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("Removed "+type+" memo table entry row (unsupported cell): "+me); |
| iter.remove(); |
| } |
| } |
| } |
| } |
| |
| visited.add(current.getHopID()); |
| } |
| |
| ///////////////////////////////////////////////////////// |
| // Cost model fused operators w/ materialization points |
| ////////// |
| |
| private double getPlanCost(CPlanMemoTable memo, PlanPartition part, |
| InterestingPoint[] matPoints,boolean[] plan, HashMap<Long, Double> computeCosts, |
| final double costBound) |
| { |
| //high level heuristic: every hop or fused operator has the following cost: |
| //WRITE + max(COMPUTE, READ), where WRITE costs are given by the output size, |
		//READ costs by the input sizes, and COMPUTE by operation-specific FLOP
| //counts times number of cells of main input, disregarding sparsity for now. |
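		//(e.g., a fused operator over a dense 1000x1000 input and output w/ 4 FLOPs
		//per cell costs 8e6/WRITE_BANDWIDTH_MEM + max(8e6/READ_BANDWIDTH_MEM,
		//4e6/COMPUTE_BANDWIDTH) seconds)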
| |
| HashSet<VisitMarkCost> visited = new HashSet<>(); |
| double costs = 0; |
| int rem = part.getRoots().size(); |
| for( Long hopID : part.getRoots() ) { |
| costs += rGetPlanCosts(memo, memo.getHopRefs().get(hopID), |
| visited, part, matPoints, plan, computeCosts, null, null, costBound-costs); |
| if( costs >= costBound && --rem > 0 ) //stop early |
| return Double.POSITIVE_INFINITY; |
| } |
| return costs; |
| } |
| |
| private double rGetPlanCosts(CPlanMemoTable memo, final Hop current, HashSet<VisitMarkCost> visited, |
| PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, |
| CostVector costsCurrent, TemplateType currentType, final double costBound) |
| { |
| final long currentHopId = current.getHopID(); |
| //memoization per hop id and cost vector to account for redundant |
| //computation without double counting materialized results or compute |
| //costs of complex operation DAGs within a single fused operator |
| if( !visited.add(new VisitMarkCost(currentHopId, |
| (costsCurrent==null || currentType==TemplateType.MAGG)?-1:costsCurrent.ID)) ) |
| return 0; //already existing |
| |
| //open template if necessary, including memoization |
| //under awareness of current plan choice |
| MemoTableEntry best = null; |
| boolean opened = (currentType == null); |
| if( memo.contains(currentHopId) ) { |
| //note: this is the inner loop of plan enumeration and hence, we do not |
| //use streams, lambda expressions, etc to avoid unnecessary overhead |
| if( currentType == null ) { |
| for( MemoTableEntry me : memo.get(currentHopId) ) |
| best = me.isValid() |
| && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) |
| && BasicPlanComparator.icompare(me, best)<0 ? me : best; |
| opened = true; |
| } |
| else { |
| for( MemoTableEntry me : memo.get(currentHopId) ) |
| best = (me.type == currentType || me.type==TemplateType.CELL) |
| && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) |
| && TypedPlanComparator.icompare(me, best, currentType)<0 ? me : best; |
| } |
| } |
| |
| //create new cost vector if opened, initialized with write costs |
| CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current)); |
| double costs = 0; |
| |
| //add other roots for multi-agg template to account for shared costs |
| if( opened && best != null && best.type == TemplateType.MAGG ) { |
| //account costs to first multi-agg root |
| if( best.input1 == currentHopId ) |
| for( int i=1; i<3; i++ ) { |
| if( !best.isPlanRef(i) ) continue; |
| costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, |
| part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG, costBound-costs); |
| if( costs >= costBound ) |
| return Double.POSITIVE_INFINITY; |
| } |
| //skip other multi-agg roots |
| else |
| return 0; |
| } |
| |
| //add compute costs of current operator to costs vector |
| if( computeCosts.containsKey(currentHopId) ) |
| costVect.computeCosts += computeCosts.get(currentHopId); |
| |
| //process children recursively |
| for( int i=0; i< current.getInput().size(); i++ ) { |
| Hop c = current.getInput().get(i); |
| if( best!=null && best.isPlanRef(i) ) |
| costs += rGetPlanCosts(memo, c, visited, part, matPoints, |
| plan, computeCosts, costVect, best.type, costBound-costs); |
| else if( best!=null && isImplicitlyFused(current, i, best.type) ) |
| costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c)); |
| else { //include children and I/O costs |
| if( part.getPartition().contains(c.getHopID()) ) |
| costs += rGetPlanCosts(memo, c, visited, part, matPoints, |
| plan, computeCosts, null, null, costBound-costs); |
| if( costVect != null && c.getDataType().isMatrix() ) |
| costVect.addInputSize(c.getHopID(), getSize(c)); |
| } |
| if( costs >= costBound ) |
| return Double.POSITIVE_INFINITY; |
| } |
| |
| //add costs for opened fused operator |
| if( opened ) { |
| double memInputs = sumInputMemoryEstimates(memo, costVect); |
| double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH_MEM |
| + Math.max(memInputs / READ_BANDWIDTH_MEM, |
| costVect.computeCosts/ COMPUTE_BANDWIDTH); |
| //read correction for distributed computation |
| if( memInputs > OptimizerUtils.getLocalMemBudget() ) |
| tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST; |
| //sparsity correction for outer-product template (and sparse-safe cell) |
| Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); |
| if( best != null && best.type == TemplateType.OUTER ) |
| tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST; |
| //write correction for known evictions in CP |
| else if( memInputs <= OptimizerUtils.getLocalMemBudget() |
| && sumTmpInputOutputSize(memo, costVect)*8 > LazyWriteBuffer.getWriteBufferLimit() ) |
| tmpCosts += costVect.outSize * 8 / WRITE_BANDWIDTH_IO; |
| costs += tmpCosts; |
| if( LOG.isTraceEnabled() ) { |
| String type = (best !=null) ? best.type.name() : "HOP"; |
| LOG.trace("Cost vector ("+type+" "+currentHopId+"): "+costVect+" -> "+tmpCosts); |
| } |
| } |
		//add costs for non-partition reads in the middle of a fused operator
| else if( part.getExtConsumed().contains(current.getHopID()) ) { |
| costs += rGetPlanCosts(memo, current, visited, part, matPoints, plan, |
| computeCosts, null, null, costBound - costs); |
| if( costs >= costBound ) |
| return Double.POSITIVE_INFINITY; |
| } |
| |
		//sanity check for valid (finite, non-negative) costs
| if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) ) |
| throw new RuntimeException("Wrong cost estimate: "+costs); |
| |
| return costs; |
| } |
| |
| private static void getComputeCosts(Hop current, HashMap<Long, Double> computeCosts) |
| { |
| //get costs for given hop |
| double costs = 1; |
| if( current instanceof UnaryOp ) { |
| switch( ((UnaryOp)current).getOp() ) { |
| case ABS: |
| case ROUND: |
| case CEIL: |
| case FLOOR: |
| case SIGN: costs = 1; break; |
| case SPROP: |
| case SQRT: costs = 2; break; |
| case EXP: costs = 18; break; |
| case SIGMOID: costs = 21; break; |
| case LOG: |
| case LOG_NZ: costs = 32; break; |
| case NCOL: |
| case NROW: |
| case PRINT: |
| case ASSERT: |
| case CAST_AS_BOOLEAN: |
| case CAST_AS_DOUBLE: |
| case CAST_AS_INT: |
| case CAST_AS_MATRIX: |
| case CAST_AS_SCALAR: costs = 1; break; |
| case SIN: costs = 18; break; |
| case COS: costs = 22; break; |
| case TAN: costs = 42; break; |
| case ASIN: costs = 93; break; |
| case ACOS: costs = 103; break; |
| case ATAN: costs = 40; break; |
| case SINH: costs = 93; break; // TODO: |
| case COSH: costs = 103; break; |
| case TANH: costs = 40; break; |
| case CUMSUM: |
| case CUMMIN: |
| case CUMMAX: |
| case CUMPROD: costs = 1; break; |
| case CUMSUMPROD: costs = 2; break; |
| default: |
| LOG.warn("Cost model not " |
| + "implemented yet for: "+((UnaryOp)current).getOp()); |
| } |
| } |
| else if( current instanceof BinaryOp ) { |
| switch( ((BinaryOp)current).getOp() ) { |
| case MULT: |
| case PLUS: |
| case MINUS: |
| case MIN: |
| case MAX: |
| case AND: |
| case OR: |
| case EQUAL: |
| case NOTEQUAL: |
| case LESS: |
| case LESSEQUAL: |
| case GREATER: |
| case GREATEREQUAL: |
| case CBIND: |
| case RBIND: costs = 1; break; |
| case INTDIV: costs = 6; break; |
| case MODULUS: costs = 8; break; |
| case DIV: costs = 22; break; |
| case LOG: |
| case LOG_NZ: costs = 32; break; |
| case POW: costs = (HopRewriteUtils.isLiteralOfValue( |
| current.getInput().get(1), 2) ? 1 : 16); break; |
| case MINUS_NZ: |
| case MINUS1_MULT: costs = 2; break; |
| case MOMENT: |
| int type = (int) (current.getInput().get(1) instanceof LiteralOp ? |
| HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2); |
| switch( type ) { |
| case 0: costs = 1; break; //count |
| case 1: costs = 8; break; //mean |
| case 2: costs = 16; break; //cm2 |
| case 3: costs = 31; break; //cm3 |
| case 4: costs = 51; break; //cm4 |
| case 5: costs = 16; break; //variance |
| } |
| break; |
| case COV: costs = 23; break; |
| default: |
| LOG.warn("Cost model not " |
| + "implemented yet for: "+((BinaryOp)current).getOp()); |
| } |
| } |
| else if( current instanceof TernaryOp ) { |
| switch( ((TernaryOp)current).getOp() ) { |
| case IFELSE: |
| case PLUS_MULT: |
| case MINUS_MULT: costs = 2; break; |
| case CTABLE: costs = 3; break; |
| case MOMENT: |
| int type = (int) (current.getInput().get(1) instanceof LiteralOp ? |
| HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2); |
| switch( type ) { |
| case 0: costs = 2; break; //count |
| case 1: costs = 9; break; //mean |
| case 2: costs = 17; break; //cm2 |
| case 3: costs = 32; break; //cm3 |
| case 4: costs = 52; break; //cm4 |
| case 5: costs = 17; break; //variance |
| } |
| break; |
| case COV: costs = 23; break; |
| default: |
| LOG.warn("Cost model not " |
| + "implemented yet for: "+((TernaryOp)current).getOp()); |
| } |
| } |
| else if( current instanceof NaryOp ) { |
| costs = HopRewriteUtils.isNary(current, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS) ? |
| current.getInput().size() : 1; |
| } |
| else if( current instanceof ParameterizedBuiltinOp ) { |
| costs = 1; |
| } |
| else if( current instanceof IndexingOp ) { |
| costs = 1; |
| } |
| else if( current instanceof ReorgOp ) { |
| costs = 1; |
| } |
| else if( current instanceof DnnOp ) { |
| switch( ((DnnOp)current).getOp() ) { |
| case BIASADD: |
| case BIASMULT: |
					costs = 2; break;
| default: |
| LOG.warn("Cost model not " |
| + "implemented yet for: "+((DnnOp)current).getOp()); |
| } |
| } |
| else if( current instanceof AggBinaryOp ) { |
| //outer product template w/ matrix-matrix |
| //or row template w/ matrix-vector or matrix-matrix |
| costs = 2 * current.getInput().get(0).getDim2(); |
| if( current.getInput().get(0).dimsKnown(true) ) |
| costs *= current.getInput().get(0).getSparsity(); |
| } |
| else if( current instanceof AggUnaryOp) { |
| switch(((AggUnaryOp)current).getOp()) { |
| case SUM: costs = 4; break; |
| case SUM_SQ: costs = 5; break; |
| case MIN: |
| case MAX: costs = 1; break; |
| default: |
| LOG.warn("Cost model not " |
| + "implemented yet for: "+((AggUnaryOp)current).getOp()); |
| } |
| switch(((AggUnaryOp)current).getDirection()) { |
| case Col: costs *= Math.max(current.getInput().get(0).getDim1(),1); break; |
| case Row: costs *= Math.max(current.getInput().get(0).getDim2(),1); break; |
| case RowCol: costs *= getSize(current.getInput().get(0)); break; |
| } |
| } |
| |
| //scale by current output size in order to correctly reflect |
| //a mix of row and cell operations in the same fused operator |
| //(e.g., row template with fused column vector operations) |
| costs *= getSize(current); |
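		//(e.g., a rowSums over an n x m input has per-cell costs 4*m, scaled
		//here by its n x 1 output size, i.e., ~4 FLOPs per input cell overall)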
| |
| computeCosts.put(current.getHopID(), costs); |
| } |
| |
| private static boolean hasNoRefToMatPoint(long hopID, |
| MemoTableEntry me, InterestingPoint[] M, boolean[] plan) { |
| return !InterestingPoint.isMatPoint(M, hopID, me, plan); |
| } |
| |
| private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) { |
| return type == TemplateType.ROW |
| && HopRewriteUtils.isMatrixMultiply(hop) && index==0 |
| && HopRewriteUtils.isTransposeOperation(hop.getInput().get(index)); |
| } |
| |
| private static boolean probePlanCache(InterestingPoint[] matPoints) { |
| return matPoints.length >= PLAN_CACHE_NUM_POINTS; |
| } |
| |
| private static boolean[] getPlan(PartitionSignature pKey) { |
| boolean[] plan = null; |
| synchronized( _planCache ) { |
| plan = _planCache.get(pKey); |
| } |
| if( DMLScript.STATISTICS ) { |
| if( plan != null ) |
| Statistics.incrementCodegenPlanCacheHits(); |
| Statistics.incrementCodegenPlanCacheTotal(); |
| } |
| return plan; |
| } |
| |
| private static void putPlan(PartitionSignature pKey, boolean[] plan) { |
| synchronized( _planCache ) { |
			//maintain size of plan cache (FIFO eviction of oldest entry)
| if( _planCache.size() >= PLAN_CACHE_SIZE ) { |
| Iterator<Entry<PartitionSignature, boolean[]>> iter = |
| _planCache.entrySet().iterator(); |
| iter.next(); |
| iter.remove(); |
| } |
| |
			//add as last (most recent) entry
| _planCache.put(pKey, plan); |
| } |
| } |
| |
| private class CostVector { |
| public final long ID; |
| public final double outSize; |
| public double computeCosts = 0; |
| public final HashMap<Long, Double> inSizes = new HashMap<>(); |
| |
| public CostVector(double outputSize) { |
| ID = COST_ID.getNextID(); |
| outSize = outputSize; |
| } |
| public void addInputSize(long hopID, double inputSize) { |
| //ensures that input sizes are not double counted |
| inSizes.put(hopID, inputSize); |
| } |
| @SuppressWarnings("unused") |
| public double getInputSize() { |
| return inSizes.values().stream() |
| .mapToDouble(d -> d.doubleValue()).sum(); |
| } |
| public double getSideInputSize() { |
| double max = getMaxInputSize(); |
| return inSizes.values().stream() |
| .filter(d -> d < max) |
| .mapToDouble(d -> d.doubleValue()).sum(); |
| } |
| public double getMaxInputSize() { |
| return inSizes.values().stream() |
| .mapToDouble(d -> d.doubleValue()).max().orElse(0); |
| } |
| public long getMaxInputSizeHopID() { |
| long id = -1; double max = 0; |
| for( Entry<Long,Double> e : inSizes.entrySet() ) |
| if( max < e.getValue() ) { |
| id = e.getKey(); |
| max = e.getValue(); |
| } |
| return id; |
| } |
| @Override |
| public String toString() { |
| return "["+outSize+", "+computeCosts+", {" |
| +Arrays.toString(inSizes.keySet().toArray(new Long[0]))+", " |
| +Arrays.toString(inSizes.values().toArray(new Double[0]))+"}]"; |
| } |
| } |
| |
| private static class StaticCosts { |
| public final HashMap<Long, Double> _computeCosts; |
| public final double _compute; |
| public final double _read; |
| public final double _write; |
| public final double _minSparsity; |
| public StaticCosts(HashMap<Long,Double> allComputeCosts, double computeCost, double readCost, double writeCost, double minSparsity) { |
| _computeCosts = allComputeCosts; |
| _compute = computeCost; |
| _read = readCost; |
| _write = writeCost; |
| _minSparsity = minSparsity; |
| } |
| public double getMinCosts() { |
| return Math.max(_read, _compute) + _write; |
| } |
| } |
| |
| private static class AggregateInfo { |
| public final HashMap<Long,Hop> _aggregates; |
| public final HashSet<Long> _inputAggs = new HashSet<>(); |
| public final HashSet<Long> _fusedInputs = new HashSet<>(); |
| public AggregateInfo(Hop aggregate) { |
| _aggregates = new HashMap<>(); |
| _aggregates.put(aggregate.getHopID(), aggregate); |
| } |
| public void addInputAggregate(long hopID) { |
| _inputAggs.add(hopID); |
| } |
| public void addFusedInput(long hopID) { |
| _fusedInputs.add(hopID); |
| } |
| public boolean isMergable(AggregateInfo that) { |
| //check independence |
| boolean ret = _aggregates.size()<3 |
| && _aggregates.size()+that._aggregates.size()<=3; |
| for( Long hopID : that._aggregates.keySet() ) |
| ret &= !_inputAggs.contains(hopID); |
| for( Long hopID : _aggregates.keySet() ) |
| ret &= !that._inputAggs.contains(hopID); |
| //check partial shared reads |
| ret &= CollectionUtils.containsAny(_fusedInputs, that._fusedInputs); |
| //check consistent sizes (result correctness) |
| Hop in1 = _aggregates.values().iterator().next(); |
| Hop in2 = that._aggregates.values().iterator().next(); |
| return ret && HopRewriteUtils.isEqualSize( |
| in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0), |
| in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0)); |
| } |
| public AggregateInfo merge(AggregateInfo that) { |
| _aggregates.putAll(that._aggregates); |
| _inputAggs.addAll(that._inputAggs); |
| _fusedInputs.addAll(that._fusedInputs); |
| return this; |
| } |
| @Override |
| public String toString() { |
| return "["+Arrays.toString(_aggregates.keySet().toArray(new Long[0]))+": " |
| +"{"+Arrays.toString(_inputAggs.toArray(new Long[0]))+"}," |
| +"{"+Arrays.toString(_fusedInputs.toArray(new Long[0]))+"}]"; |
| } |
| } |
| |
| private static class PartitionSignature { |
| private final int partNodes, inputNodes, rootNodes, matPoints; |
| private final double cCompute, cRead, cWrite, cPlan0, cPlanN; |
| |
| public PartitionSignature(PlanPartition part, int M, StaticCosts costs, double cP0, double cPN) { |
| partNodes = part.getPartition().size(); |
| inputNodes = part.getInputs().size(); |
| rootNodes = part.getRoots().size(); |
| matPoints = M; |
| cCompute = costs._compute; |
| cRead = costs._read; |
| cWrite = costs._write; |
| cPlan0 = cP0; |
| cPlanN = cPN; |
| } |
| @Override |
| public int hashCode() { |
| return UtilFunctions.intHashCode( |
| Arrays.hashCode(new int[]{partNodes, inputNodes, rootNodes, matPoints}), |
| Arrays.hashCode(new double[]{cCompute, cRead, cWrite, cPlan0, cPlanN})); |
| } |
| @Override |
| public boolean equals(Object o) { |
| if( !(o instanceof PartitionSignature) ) |
| return false; |
| PartitionSignature that = (PartitionSignature) o; |
| return partNodes == that.partNodes |
| && inputNodes == that.inputNodes |
| && rootNodes == that.rootNodes |
| && matPoints == that.matPoints |
| && cCompute == that.cCompute |
| && cRead == that.cRead |
| && cWrite == that.cWrite |
| && cPlan0 == that.cPlan0 |
| && cPlanN == that.cPlanN; |
| } |
| } |
| } |