/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.codegen.opt;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
import org.apache.sysds.hops.codegen.template.TemplateRow;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
/**
 * This cost-based plan selection algorithm chooses fused operators
 * based on the DAG structure and resulting overall costs. This primarily
 * includes decisions on materialization points, but also heuristics for
 * template types and composed multi-output templates.
 */
public class PlanSelectionFuseCostBased extends PlanSelection
{
private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBased.class.getName());
//common bandwidth characteristics, with a conservative write bandwidth in order
//to cover result allocation, write into main memory, and potential evictions
private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024; //2GB/s
private static final double READ_BANDWIDTH = 32d*1024*1024*1024; //32GB/s
private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core
* InfrastructureAnalyzer.getLocalParallelism();
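//e.g., with 16 logical cores this yields 2 GFLOP/s * 16 = 32 GFLOP/s aggregate compute bandwidth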
private static final IDSequence COST_ID = new IDSequence();
private static final TemplateRow ROW_TPL = new TemplateRow();
@Override
public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots)
{
//step 1: analyze connected partitions (nodes, roots, mat points)
Collection<PlanPartition> parts = PlanAnalyzer.analyzePlanPartitions(memo, roots, false);
//step 2: optimize individual plan partitions
int sumMatPoints = 0;
for( PlanPartition part : parts ) {
//create composite templates (within the partition)
createAndAddMultiAggPlans(memo, part.getPartition(), part.getRoots());
//plan enumeration and plan selection
selectPlans(memo, part.getPartition(), part.getRoots(), part.getMatPoints());
sumMatPoints += part.getMatPoints().size();
}
//step 3: add composite templates (across partitions)
createAndAddMultiAggPlans(memo, roots);
//take all distinct best plans
for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
memo.setDistinct(e.getKey(), e.getValue());
//maintain statistics
if( DMLScript.STATISTICS )
Statistics.incrementCodegenEnumAll(UtilFunctions.pow(2, sumMatPoints));
}
//within-partition multi-agg templates
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R)
{
//create index of plans that reference full aggregates to avoid circular dependencies
HashSet<Long> refHops = new HashSet<>();
for( Entry<Long, List<MemoTableEntry>> e : memo.getPlans().entrySet() )
if( !e.getValue().isEmpty() ) {
Hop hop = memo.getHopRefs().get(e.getKey());
for( Hop c : hop.getInput() )
refHops.add(c.getHopID());
}
//find all full aggregations (the fact that they are in the same partition guarantees
//that they also share common subexpressions; full aggregations are by definition root nodes)
ArrayList<Long> fullAggs = new ArrayList<>();
for( Long hopID : R ) {
Hop root = memo.getHopRefs().get(hopID);
if( !refHops.contains(hopID) && isMultiAggregateRoot(root) )
fullAggs.add(hopID);
}
if( LOG.isTraceEnabled() ) {
LOG.trace("Found within-partition ua(RC) aggregations: " +
Arrays.toString(fullAggs.toArray(new Long[0])));
}
//construct and add multiagg template plans (w/ max 3 aggregations)
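//e.g., fullAggs=[a,b,c,d,e] yields the candidates magg(a,b,c) and magg(d,e),
//while a trailing single aggregation (ito-i < 2) is left unfused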
for( int i=0; i<fullAggs.size(); i+=3 ) {
int ito = Math.min(i+3, fullAggs.size());
if( ito-i >= 2 ) {
MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1, ito-i);
if( isValidMultiAggregate(memo, me) ) {
for( int j=i; j<ito; j++ ) {
memo.add(memo.getHopRefs().get(fullAggs.get(j)), me);
if( LOG.isTraceEnabled() )
LOG.trace("Added multiagg plan: "+fullAggs.get(j)+" "+me);
}
}
else if( LOG.isTraceEnabled() ) {
LOG.trace("Removed invalid multiagg plan: "+me);
}
}
}
}
//across-partition multi-agg templates with shared reads
private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots)
{
//collect full aggregations as initial set of candidates
HashSet<Long> fullAggs = new HashSet<>();
Hop.resetVisitStatus(roots);
for( Hop hop : roots )
rCollectFullAggregates(hop, fullAggs);
Hop.resetVisitStatus(roots);
//remove operators with assigned multi-agg plans
fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));
//check applicability for further analysis
if( fullAggs.size() <= 1 )
return;
if( LOG.isTraceEnabled() ) {
LOG.trace("Found across-partition ua(RC) aggregations: " +
Arrays.toString(fullAggs.toArray(new Long[0])));
}
//collect information for all candidates
//(subsumed aggregations, and inputs to fused operators)
List<AggregateInfo> aggInfos = new ArrayList<>();
for( Long hopID : fullAggs ) {
Hop aggHop = memo.getHopRefs().get(hopID);
AggregateInfo tmp = new AggregateInfo(aggHop);
for( int i=0; i<aggHop.getInput().size(); i++ ) {
Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ?
aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL);
}
if( tmp._fusedInputs.isEmpty() ) {
if( HopRewriteUtils.isMatrixMultiply(aggHop) ) {
tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
}
else
tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
}
aggInfos.add(tmp);
}
if( LOG.isTraceEnabled() ) {
LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
for( AggregateInfo info : aggInfos )
LOG.trace(info);
}
//sort aggregations by num dependencies to simplify merging
//clusters of aggregations with parallel dependencies
aggInfos = aggInfos.stream()
.sorted(Comparator.comparing(a -> a._inputAggs.size()))
.collect(Collectors.toList());
//greedy grouping of multi-agg candidates
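//fixpoint iteration: sweep all candidate pairs and merge mergable ones;
//converged once an entire sweep performs no merge (merged stays null)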
boolean converged = false;
while( !converged ) {
AggregateInfo merged = null;
for( int i=0; i<aggInfos.size(); i++ ) {
AggregateInfo current = aggInfos.get(i);
for( int j=i+1; j<aggInfos.size(); j++ ) {
AggregateInfo that = aggInfos.get(j);
if( current.isMergable(that) ) {
merged = current.merge(that);
aggInfos.remove(j); j--;
}
}
}
converged = (merged == null);
}
if( LOG.isTraceEnabled() ) {
LOG.trace("Merged across-partition ua(RC) aggregation info: ");
for( AggregateInfo info : aggInfos )
LOG.trace(info);
}
//construct and add multiagg template plans (w/ max 3 aggregations)
for( AggregateInfo info : aggInfos ) {
if( info._aggregates.size()<=1 )
continue;
Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length);
for( int i=0; i<aggs.length; i++ ) {
memo.add(memo.getHopRefs().get(aggs[i]), me);
addBestPlan(aggs[i], me);
if( LOG.isTraceEnabled() )
LOG.trace("Added multiagg* plan: "+aggs[i]+" "+me);
}
}
}
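//a multi-agg root is a full aggregation over the entire input: ua(RC) with
//sum/sum_sq/min/max, or a 1x1 matrix multiply with transposed left input
//(an inner product such as t(x)%*%y)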
private static boolean isMultiAggregateRoot(Hop root) {
return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
&& ((AggUnaryOp)root).getDirection()==Direction.RowCol)
|| (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1
&& HopRewriteUtils.isTransposeOperation(root.getInput().get(0)));
}
private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) {
//ensure input consistent sizes (otherwise potential for incorrect results)
boolean ret = true;
Hop refSize = memo.getHopRefs().get(me.input1).getInput().get(0);
for( int i=1; ret && i<3; i++ ) {
if( me.isPlanRef(i) )
ret &= HopRewriteUtils.isEqualSize(refSize,
memo.getHopRefs().get(me.input(i)).getInput().get(0));
}
//ensure that aggregates are independent of each other, i.e.,
//they do not have potentially transitive parent-child references
for( int i=0; ret && i<3; i++ )
if( me.isPlanRef(i) ) {
HashSet<Long> probe = new HashSet<>();
for( int j=0; j<3; j++ )
if( i != j )
probe.add(me.input(j));
ret &= rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), probe);
}
return ret;
}
private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) {
boolean ret = true;
for( Hop c : current.getInput() )
ret &= rCheckMultiAggregate(c, probe);
ret &= !probe.contains(current.getHopID());
return ret;
}
private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) {
if( current.isVisited() )
return;
//collect all applicable full aggregations per read
if( isMultiAggregateRoot(current) )
aggs.add(current.getHopID());
//recursively process children
for( Hop c : current.getInput() )
rCollectFullAggregates(c, aggs);
current.setVisited();
}
private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) {
//collect input aggregates (dependents)
if( isMultiAggregateRoot(current) )
aggInfo.addInputAggregate(current.getHopID());
//recursively process children
MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null;
for( int i=0; i<current.getInput().size(); i++ ) {
Hop c = current.getInput().get(i);
if( me != null && me.isPlanRef(i) )
rExtractAggregateInfo(memo, c, aggInfo, type);
else {
if( type != null && c.getDataType().isMatrix() ) //add fused input
aggInfo.addFusedInput(c.getHopID());
rExtractAggregateInfo(memo, c, aggInfo, null);
}
}
}
private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M)
{
//prune row aggregates with pure cellwise operations
for( Long hopID : R ) {
MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
if( me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL)
&& isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
List<MemoTableEntry> excludeList = memo.get(hopID, TemplateType.ROW);
memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(excludeList));
if( LOG.isTraceEnabled() ) {
LOG.trace("Removed row memo table entries w/o aggregation: "
+ Arrays.toString(excludeList.toArray(new MemoTableEntry[0])));
}
}
}
//prune suboptimal outer product plans that are dominated by outer product plans w/ the same number
//of references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))),
//we'd prune sum(X*(U%*%t(V))*Z) with Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern)
for( Long hopID : partition ) {
if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) {
List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER);
MemoTableEntry me1 = entries.get(0);
MemoTableEntry me2 = entries.get(1);
MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
if( rmEntry != null ) {
memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
memo.getPlansExcludeListed().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
if( LOG.isTraceEnabled() )
LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
}
}
}
//if no materialization points, use basic fuse-all w/ partition awareness
if( M == null || M.isEmpty() ) {
for( Long hopID : R )
rSelectPlansFuseAll(memo,
memo.getHopRefs().get(hopID), null, partition);
}
else {
//TODO branch-and-bound pruning via skip-ahead in the enumeration algorithm below;
//for now we use exhaustive enumeration for early experiments
//obtain hop compute costs per cell once
HashMap<Long, Double> computeCosts = new HashMap<>();
for( Long hopID : R )
rGetComputeCosts(memo.getHopRefs().get(hopID), partition, computeCosts);
//scan linearized search space, w/ skips for branch and bound pruning
int len = (int)Math.pow(2, M.size());
boolean[] bestPlan = null;
double bestC = Double.MAX_VALUE;
for( int i=0; i<len; i++ ) {
//construct assignment
boolean[] plan = createAssignment(M.size(), i);
//cost assignment on hops
double C = getPlanCost(memo, partition, R, M, plan, computeCosts);
if( LOG.isTraceEnabled() )
LOG.trace("Enum: "+Arrays.toString(plan)+" -> "+C);
//cost comparisons
if( bestPlan == null || C < bestC ) {
bestC = C;
bestPlan = plan;
if( LOG.isTraceEnabled() )
LOG.trace("Enum: Found new best plan.");
}
}
if( DMLScript.STATISTICS ) {
Statistics.incrementCodegenEnumAllP(len);
Statistics.incrementCodegenEnumEval(len);
}
//prune memo table wrt best plan and select plans
HashSet<Long> visited = new HashSet<>();
for( Long hopID : R )
rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID),
visited, partition, M, bestPlan);
HashSet<Long> visited2 = new HashSet<>();
for( Long hopID : R )
rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID),
visited2, partition, M, bestPlan);
for( Long hopID : R )
rSelectPlansFuseAll(memo,
memo.getHopRefs().get(hopID), null, partition);
}
}
private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
//consider all aggregations other than the root operation
MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
boolean ret = true;
for(int i=0; i<3; i++)
if( me.isPlanRef(i) )
ret &= rIsRowTemplateWithoutAgg(memo,
current.getInput().get(i), visited);
return ret;
}
private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
if( visited.contains(current.getHopID()) )
return true;
boolean ret = true;
MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
for(int i=0; i<3; i++)
if( me.isPlanRef(i) )
ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited);
ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp);
visited.add(current.getHopID());
return ret;
}
private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
//memoization (not via hop visit status because we are in the middle of the DAG)
if( visited.contains(current.getHopID()) )
return;
//remove memo table entries if necessary
long hopID = current.getHopID();
if( partition.contains(hopID) && memo.contains(hopID) ) {
Iterator<MemoTableEntry> iter = memo.get(hopID).iterator();
while( iter.hasNext() ) {
MemoTableEntry me = iter.next();
if( !hasNoRefToMaterialization(me, M, plan) && me.type!=TemplateType.OUTER ){
iter.remove();
if( LOG.isTraceEnabled() )
LOG.trace("Removed memo table entry: "+me);
}
}
}
//process children recursively
for( Hop c : current.getInput() )
rPruneSuboptimalPlans(memo, c, visited, partition, M, plan);
visited.add(current.getHopID());
}
private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
//memoization (not via hop visit status because we are in the middle of the DAG)
if( visited.contains(current.getHopID()) )
return;
//process children recursively
for( Hop c : current.getInput() )
rPruneInvalidPlans(memo, c, visited, partition, M, plan);
//find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs,
//i.e., plans that become invalid after the previous pruning step
long hopID = current.getHopID();
if( partition.contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) {
for( MemoTableEntry me : memo.get(hopID) ) {
if( me.type==TemplateType.ROW ) {
//convert leaf node with pure vector inputs
if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) {
me.type = TemplateType.CELL;
if( LOG.isTraceEnabled() )
LOG.trace("Converted leaf memo table entry from row to cell: "+me);
}
//convert inner node without row template input
if( me.hasPlanRef() && !ROW_TPL.open(current) ) {
boolean hasRowInput = false;
for( int i=0; i<3; i++ )
if( me.isPlanRef(i) )
hasRowInput |= memo.contains(me.input(i), TemplateType.ROW);
if( !hasRowInput ) {
me.type = TemplateType.CELL;
if( LOG.isTraceEnabled() )
LOG.trace("Converted inner memo table entry from row to cell: "+me);
}
}
}
}
}
visited.add(current.getHopID());
}
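//maps the enumeration index pos to a boolean assignment over the len materialization
//points; entry i is true iff bit i of pos (MSB first) is zero,
//e.g., len=3: pos=0 -> [T,T,T], pos=5 -> [F,T,F], pos=7 -> [F,F,F]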
private static boolean[] createAssignment(int len, int pos) {
boolean[] ret = new boolean[len];
int tmp = pos;
for( int i=0; i<len; i++ ) {
ret[i] = (tmp < (int)Math.pow(2, len-i-1));
tmp %= Math.pow(2, len-i-1);
}
return ret;
}
/////////////////////////////////////////////////////////
// Cost model for fused operators w/ materialization points
/////////////////////////////////////////////////////////
private static double getPlanCost(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R,
ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts)
{
//high level heuristic: every hop or fused operator has the following cost:
//WRITE + max(COMPUTE, READ), where WRITE costs are given by the output size,
//READ costs by the input sizes, and COMPUTE by operation specific FLOP
//counts times number of cells of main input, disregarding sparsity for now.
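//e.g., a fused operator with a 1M-cell output, two distinct 1M-cell inputs, and
//4 FLOPs per cell of the main input costs
//1M*8/WRITE_BANDWIDTH + max(4*1M/COMPUTE_BANDWIDTH, 2M*8/READ_BANDWIDTH) seconds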
HashSet<Pair<Long,Long>> visited = new HashSet<>();
double costs = 0;
for( Long hopID : R )
costs += rGetPlanCosts(memo, memo.getHopRefs().get(hopID),
visited, partition, M, plan, computeCosts, null, null);
return costs;
}
private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Pair<Long,Long>> visited, HashSet<Long> partition,
ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType)
{
//memoization per hop id and cost vector to account for redundant
//computation without double counting materialized results or compute
//costs of complex operation DAGs within a single fused operator
Pair<Long,Long> tag = Pair.of(current.getHopID(),
(costsCurrent==null)?0:costsCurrent.ID);
if( visited.contains(tag) )
return 0;
visited.add(tag);
//open template if necessary, including memoization
//under awareness of current plan choice
MemoTableEntry best = null;
boolean opened = false;
if( memo.contains(current.getHopID()) ) {
if( currentType == null ) {
best = memo.get(current.getHopID()).stream()
.filter(p -> p.isValid())
.filter(p -> hasNoRefToMaterialization(p, M, plan))
.min(new BasicPlanComparator()).orElse(null);
opened = true;
}
else {
best = memo.get(current.getHopID()).stream()
.filter(p -> p.type==currentType || p.type==TemplateType.CELL)
.filter(p -> hasNoRefToMaterialization(p, M, plan))
.min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
.orElse(null);
}
}
//create new cost vector if opened, initialized with write costs
CostVector costVect = !opened ? costsCurrent :
new CostVector(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1));
//add compute costs of current operator to costs vector
if( partition.contains(current.getHopID()) )
costVect.computeCosts += computeCosts.get(current.getHopID());
//process children recursively
double costs = 0;
for( int i=0; i< current.getInput().size(); i++ ) {
Hop c = current.getInput().get(i);
if( best!=null && best.isPlanRef(i) )
costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, costVect, best.type);
else if( best!=null && isImplicitlyFused(current, i, best.type) )
costVect.addInputSize(c.getInput().get(0).getHopID(), Math.max(c.getDim1(),1)*Math.max(c.getDim2(),1));
else { //include children and I/O costs
costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, null, null);
if( costVect != null && c.getDataType().isMatrix() )
costVect.addInputSize(c.getHopID(), Math.max(c.getDim1(),1)*Math.max(c.getDim2(),1));
}
}
//add costs for opened fused operator
if( partition.contains(current.getHopID()) ) {
if( opened ) {
if( LOG.isTraceEnabled() )
LOG.trace("Cost vector for fused operator (hop "+current.getHopID()+"): "+costVect);
costs += costVect.outSize * 8 / WRITE_BANDWIDTH; //time for output write
costs += Math.max(
costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH,
costVect.getSumInputSizes() * 8 / READ_BANDWIDTH);
}
//add costs for non-partition read in the middle of fused operator
else if( hasNonPartitionConsumer(current, partition) ) {
costs += rGetPlanCosts(memo, current, visited, partition, M, plan, computeCosts, null, null);
}
}
//sanity check non-negative costs
if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) )
throw new RuntimeException("Wrong cost estimate: "+costs);
return costs;
}
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts)
{
if( computeCosts.containsKey(current.getHopID()) )
return;
//recursively process children
for( Hop c : current.getInput() )
rGetComputeCosts(c, partition, computeCosts);
//get costs for given hop
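//(operation-specific FLOP estimates per cell of the main input,
//e.g., EXP counts as ~18 basic cell operations)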
double costs = 1;
if( current instanceof UnaryOp ) {
switch( ((UnaryOp)current).getOp() ) {
case ABS:
case ROUND:
case CEIL:
case FLOOR:
case SIGN: costs = 1; break;
case SPROP:
case SQRT: costs = 2; break;
case EXP: costs = 18; break;
case SIGMOID: costs = 21; break;
case LOG:
case LOG_NZ: costs = 32; break;
case NCOL:
case NROW:
case PRINT:
case ASSERT:
case CAST_AS_BOOLEAN:
case CAST_AS_DOUBLE:
case CAST_AS_INT:
case CAST_AS_MATRIX:
case CAST_AS_SCALAR: costs = 1; break;
case SIN: costs = 18; break;
case COS: costs = 22; break;
case TAN: costs = 42; break;
case ASIN: costs = 93; break;
case ACOS: costs = 103; break;
case ATAN: costs = 40; break;
case SINH: costs = 93; break; //TODO refine; sinh/cosh/tanh currently mirror asin/acos/atan costs
case COSH: costs = 103; break;
case TANH: costs = 40; break;
case CUMSUM:
case CUMMIN:
case CUMMAX:
case CUMPROD: costs = 1; break;
case CUMSUMPROD: costs = 2; break;
default:
LOG.warn("Cost model not "
+ "implemented yet for: "+((UnaryOp)current).getOp());
}
}
else if( current instanceof BinaryOp ) {
switch( ((BinaryOp)current).getOp() ) {
case MULT:
case PLUS:
case MINUS:
case MIN:
case MAX:
case AND:
case OR:
case EQUAL:
case NOTEQUAL:
case LESS:
case LESSEQUAL:
case GREATER:
case GREATEREQUAL:
case CBIND:
case RBIND: costs = 1; break;
case INTDIV: costs = 6; break;
case MODULUS: costs = 8; break;
case DIV: costs = 22; break;
case LOG:
case LOG_NZ: costs = 32; break;
case POW: costs = (HopRewriteUtils.isLiteralOfValue(
current.getInput().get(1), 2) ? 1 : 16); break;
case MINUS_NZ:
case MINUS1_MULT: costs = 2; break;
case MOMENT:
int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
switch( type ) {
case 0: costs = 1; break; //count
case 1: costs = 8; break; //mean
case 2: costs = 16; break; //cm2
case 3: costs = 31; break; //cm3
case 4: costs = 51; break; //cm4
case 5: costs = 16; break; //variance
}
break;
case COV: costs = 23; break;
default:
LOG.warn("Cost model not "
+ "implemented yet for: "+((BinaryOp)current).getOp());
}
}
else if( current instanceof TernaryOp ) {
switch( ((TernaryOp)current).getOp() ) {
case PLUS_MULT:
case MINUS_MULT: costs = 2; break;
case CTABLE: costs = 3; break;
case MOMENT:
int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
switch( type ) {
case 0: costs = 2; break; //count
case 1: costs = 9; break; //mean
case 2: costs = 17; break; //cm2
case 3: costs = 32; break; //cm3
case 4: costs = 52; break; //cm4
case 5: costs = 17; break; //variance
}
break;
case COV: costs = 23; break;
default:
LOG.warn("Cost model not "
+ "implemented yet for: "+((TernaryOp)current).getOp());
}
}
else if( current instanceof ParameterizedBuiltinOp ) {
costs = 1;
}
else if( current instanceof IndexingOp ) {
costs = 1;
}
else if( current instanceof ReorgOp ) {
costs = 1;
}
else if( current instanceof AggBinaryOp ) {
costs = 2; //matrix vector
}
else if( current instanceof AggUnaryOp) {
switch(((AggUnaryOp)current).getOp()) {
case SUM: costs = 4; break;
case SUM_SQ: costs = 5; break;
case MIN:
case MAX: costs = 1; break;
default:
LOG.warn("Cost model not "
+ "implemented yet for: "+((AggUnaryOp)current).getOp());
}
}
computeCosts.put(current.getHopID(), costs);
}
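//a memo table entry is applicable under a given assignment iff none of its
//plan references points to a materialization point selected (true) in plan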
private static boolean hasNoRefToMaterialization(MemoTableEntry me, ArrayList<Long> M, boolean[] plan) {
boolean ret = true;
for( int i=0; ret && i<3; i++ )
ret &= (!M.contains(me.input(i)) || !plan[M.indexOf(me.input(i))]);
return ret;
}
private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> partition) {
boolean ret = false;
for( Hop p : hop.getParent() )
ret |= !partition.contains(p.getHopID());
return ret;
}
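//in row templates the transpose of the left matrix-multiply input is fused
//into the operator (t(X)%*%y accesses X directly), so costing charges the
//size of the transpose input rather than a separate transpose operation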
private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) {
return type == TemplateType.ROW
&& HopRewriteUtils.isMatrixMultiply(hop) && index==0
&& HopRewriteUtils.isTransposeOperation(hop.getInput().get(index));
}
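//cost vector of a single fused operator: output size (in cells), accumulated
//per-cell compute costs, and distinct input sizes keyed by hop ID to avoid
//double counting shared reads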
private static class CostVector {
public final long ID;
public final double outSize;
public double computeCosts = 0;
public final HashMap<Long, Double> inSizes = new HashMap<>();
public CostVector(double outputSize) {
ID = COST_ID.getNextID();
outSize = outputSize;
}
public void addInputSize(long hopID, double inputSize) {
//ensures that input sizes are not double counted
inSizes.put(hopID, inputSize);
}
public double getSumInputSizes() {
return inSizes.values().stream()
.mapToDouble(d -> d.doubleValue()).sum();
}
public double getMaxInputSize() {
return inSizes.values().stream()
.mapToDouble(d -> d.doubleValue()).max().orElse(0);
}
@Override
public String toString() {
return "["+outSize+", "+computeCosts+", {"
+Arrays.toString(inSizes.keySet().toArray(new Long[0]))+", "
+Arrays.toString(inSizes.values().toArray(new Double[0]))+"}]";
}
}
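//across-partition multi-agg candidate: its grouped aggregates, transitively
//contained input aggregates (for independence checks), and fused inputs
//(for detecting partial shared reads)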
private static class AggregateInfo {
public final HashMap<Long,Hop> _aggregates;
public final HashSet<Long> _inputAggs = new HashSet<>();
public final HashSet<Long> _fusedInputs = new HashSet<>();
public AggregateInfo(Hop aggregate) {
_aggregates = new HashMap<>();
_aggregates.put(aggregate.getHopID(), aggregate);
}
public void addInputAggregate(long hopID) {
_inputAggs.add(hopID);
}
public void addFusedInput(long hopID) {
_fusedInputs.add(hopID);
}
public boolean isMergable(AggregateInfo that) {
//check independence
boolean ret = _aggregates.size()<3
&& _aggregates.size()+that._aggregates.size()<=3;
for( Long hopID : that._aggregates.keySet() )
ret &= !_inputAggs.contains(hopID);
for( Long hopID : _aggregates.keySet() )
ret &= !that._inputAggs.contains(hopID);
//check partial shared reads
ret &= CollectionUtils.containsAny(_fusedInputs, that._fusedInputs);
//check consistent sizes (result correctness)
Hop in1 = _aggregates.values().iterator().next();
Hop in2 = that._aggregates.values().iterator().next();
return ret && HopRewriteUtils.isEqualSize(
in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0),
in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0));
}
public AggregateInfo merge(AggregateInfo that) {
_aggregates.putAll(that._aggregates);
_inputAggs.addAll(that._inputAggs);
_fusedInputs.addAll(that._fusedInputs);
return this;
}
@Override
public String toString() {
return "["+Arrays.toString(_aggregates.keySet().toArray(new Long[0]))+": "
+"{"+Arrays.toString(_inputAggs.toArray(new Long[0]))+"},"
+"{"+Arrays.toString(_fusedInputs.toArray(new Long[0]))+"}]";
}
}
}