blob: b349a5f175e69c1c0c9a63dfd15eb5c68057ac67 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.ipa;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import java.util.Set;
/**
 * Auxiliary data structure to hold function call summaries in terms
 * of information about number of function calls, consistent dimensions,
 * consistent sparsity, and dimension-preserving functions.
 */
public class FunctionCallSizeInfo
{
	//basic function call graph to obtain size information
	private final FunctionCallGraph _fgraph;
	
	//functions that are subject to size propagation
	//(called once or multiple times with consistent sizes)
	private final Set<String> _fcand;
	
	//functions that are not subject to size propagation
	//but preserve the dimensions (used to propagate inputs
	//to subsequent statement blocks and functions)
	private final Set<String> _fcandUnary;
	
	//indicators for which function arguments of valid functions it
	//is safe to propagate the number of non-zeros
	//(mapping from function keys to set of function input positions)
	private final Map<String, Set<Integer>> _fcandSafeNNZ;
	
	//indicators which literal function arguments can be safely
	//propagated into and replaced in the respective functions
	//(mapping from function keys to set of function input positions)
	private final Map<String, Set<Integer>> _fSafeLiterals;
	
	/**
	 * Constructs the function call summary for all functions
	 * reachable from the main program.
	 *
	 * @param fgraph function call graph
	 */
	public FunctionCallSizeInfo(FunctionCallGraph fgraph) {
		this(fgraph, true);
	}
	
	/**
	 * Constructs a function call summary over the given call graph.
	 * If {@code init} is false, the summary starts out empty (no valid
	 * functions, no safe nnz inputs, no safe literals) and is only
	 * populated through the public mutators such as
	 * {@link #addDimsPreservingFunction(String)}.
	 *
	 * @param fgraph function call graph
	 * @param init if true, analyze all functions reachable from the main
	 *             program; if false, leave the summary empty
	 */
	public FunctionCallSizeInfo(FunctionCallGraph fgraph, boolean init) {
		_fgraph = fgraph;
		_fcand = new HashSet<>();
		_fcandUnary = new HashSet<>();
		_fcandSafeNNZ = new HashMap<>();
		_fSafeLiterals = new HashMap<>();
		//honor the init flag (previously the analysis always ran)
		if( init )
			constructFunctionCallSizeInfo();
	}
	
	/**
	 * Gets the number of function calls to a given function.
	 *
	 * @param fkey function key
	 * @return number of function calls, or 0 if the call graph has no
	 *         call list for this function (e.g., removed functions)
	 */
	public int getFunctionCallCount(String fkey) {
		List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
		return (flist != null) ? flist.size() : 0;
	}
	
	/**
	 * Indicates if the given function is valid for statistics
	 * propagation.
	 *
	 * @param fkey function key
	 * @return true if valid
	 */
	public boolean isValidFunction(String fkey) {
		return _fcand.contains(fkey);
	}
	
	/**
	 * Gets the set of functions that are valid for statistics
	 * propagation. The returned set is the internal (mutable) set.
	 *
	 * @return set of function keys
	 */
	public Set<String> getValidFunctions() {
		return _fcand;
	}
	
	/**
	 * Gets the set of functions that are invalid for statistics
	 * propagation, i.e., the reachable functions of the call graph
	 * excluding the set of valid functions (as computed by
	 * {@code FunctionCallGraph.getReachableFunctions(Set)}).
	 *
	 * @return set of function keys.
	 */
	public Set<String> getInvalidFunctions() {
		return _fgraph.getReachableFunctions(getValidFunctions());
	}
	
	/**
	 * Adds a function to the set of dimension-preserving
	 * functions.
	 *
	 * @param fkey function key
	 */
	public void addDimsPreservingFunction(String fkey) {
		_fcandUnary.add(fkey);
	}
	
	/**
	 * Gets the set of dimension-preserving functions, i.e.,
	 * functions with one matrix input and output of equal
	 * dimension sizes. The returned set is the internal (mutable) set.
	 *
	 * @return set of function keys
	 */
	public Set<String> getDimsPreservingFunctions() {
		return _fcandUnary;
	}
	
	/**
	 * Indicates if the given function belongs to the set
	 * of dimension-preserving functions.
	 *
	 * @param fkey function key
	 * @return true if the function is dimension-preserving
	 */
	public boolean isDimsPreservingFunction(String fkey) {
		return _fcandUnary.contains(fkey);
	}
	
	/**
	 * Indicates if the given function input allows for safe
	 * nnz propagation, i.e., all function calls have a consistent
	 * and known number of non-zeros for this input.
	 *
	 * @param fkey function key
	 * @param pos function input position
	 * @return true if nnz can safely be propagated
	 */
	public boolean isSafeNnz(String fkey, int pos) {
		Set<Integer> positions = _fcandSafeNNZ.get(fkey);
		return positions != null && positions.contains(pos);
	}
	
	/**
	 * Indicates if the given function has at least one input
	 * that allows for safe literal propagation and replacement,
	 * i.e., all function calls have consistent literal inputs.
	 *
	 * @param fkey function key
	 * @return true if a literal can be safely propagated
	 */
	public boolean hasSafeLiterals(String fkey) {
		Set<Integer> positions = _fSafeLiterals.get(fkey);
		return positions != null && !positions.isEmpty();
	}
	
	/**
	 * Indicates if the given function input allows for safe
	 * literal propagation and replacement, i.e., all function calls
	 * have consistent literal inputs.
	 *
	 * @param fkey function key
	 * @param pos function input position
	 * @return true if the literal can be safely propagated
	 */
	public boolean isSafeLiteral(String fkey, int pos) {
		Set<Integer> positions = _fSafeLiterals.get(fkey);
		return positions != null && positions.contains(pos);
	}
	
	/**
	 * Runs the full analysis: valid functions, then safe nnz inputs
	 * (for valid functions only), then safe literal inputs (for all
	 * reachable functions). Step 2 depends on the result of step 1.
	 */
	private void constructFunctionCallSizeInfo() {
		//step 1: determine function candidates by evaluating all function calls
		determineValidFunctions();
		//step 2: determine safe nnz propagation per input (valid functions only)
		determineSafeNnzInputs();
		//step 3: determine safe literal replacement per input (all functions)
		determineSafeLiterals();
	}
	
	/** Step 1: collects functions called once, or multiple times with consistent inputs. */
	private void determineValidFunctions() {
		for( String fkey : _fgraph.getReachableFunctions() ) {
			List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
			if( flist == null || flist.isEmpty() ) //robustness removed functions
				continue;
			//condition 1: function called just once
			//condition 2: multiple calls (if allowed) with consistent input sizes/literals
			if( flist.size() == 1 
				|| (InterProceduralAnalysis.ALLOW_MULTIPLE_FUNCTION_CALLS && hasConsistentCalls(flist)) )
				_fcand.add(fkey);
		}
	}
	
	/** Compares the inputs of the first call against those of all other calls. */
	private static boolean hasConsistentCalls(List<FunctionOp> flist) {
		List<Hop> firstInputs = flist.get(0).getInput();
		for( int i=1; i<flist.size(); i++ ) {
			List<Hop> otherInputs = flist.get(i).getInput();
			for( int j=0; j<firstInputs.size(); j++ )
				if( !isConsistentInput(firstInputs.get(j), otherInputs.get(j)) )
					return false; //early exit, one mismatch suffices
		}
		return true;
	}
	
	/** Checks that two call arguments at the same position have equal sizes and literal values. */
	private static boolean isConsistentInput(Hop h1, Hop h2) {
		//check matrix and scalar sizes (known dims, equal nnz whether known or
		//unknown; safeness of nnz propagation is determined later per input)
		if( !(h1.dimsKnown() && h2.dimsKnown()
			&& h1.getDim1()==h2.getDim1()
			&& h1.getDim2()==h2.getDim2()
			&& h1.getNnz()==h2.getNnz()) )
			return false;
		//check literal values: both literals of equal value, or neither a literal
		if( h1 instanceof LiteralOp )
			return h2 instanceof LiteralOp
				&& HopRewriteUtils.isEqualValue((LiteralOp)h1, (LiteralOp)h2);
		return !(h2 instanceof LiteralOp);
	}
	
	/** Step 2: records, per valid function, the input positions with known nnz. */
	private void determineSafeNnzInputs() {
		for( String fkey : _fcand ) {
			List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
			if( flist == null || flist.isEmpty() ) //robustness removed functions
				continue;
			//if nnz known it is safe to propagate those nnz because for multiple calls
			//we checked for equivalence and hence all calls have the same nnz
			List<Hop> inputs = flist.get(0).getInput();
			Set<Integer> safe = new HashSet<>();
			for( int j=0; j<inputs.size(); j++ )
				if( inputs.get(j).getNnz() >= 0 ) //was get(0): checked only the first input
					safe.add(j);
			_fcandSafeNNZ.put(fkey, safe);
		}
	}
	
	/** Step 3: records, per reachable function, the inputs that are equal literals in all calls. */
	private void determineSafeLiterals() {
		for( String fkey : _fgraph.getReachableFunctions() ) {
			List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
			if( flist == null || flist.isEmpty() ) //robustness removed functions
				continue;
			List<Hop> firstInputs = flist.get(0).getInput();
			//initialize w/ all literals of first call
			Set<Integer> safe = new HashSet<>();
			for( int j=0; j<firstInputs.size(); j++ )
				if( firstInputs.get(j) instanceof LiteralOp )
					safe.add(j);
			//retain only positions that are equal-valued literals in all other calls
			for( int i=1; i<flist.size(); i++ ) {
				List<Hop> otherInputs = flist.get(i).getInput();
				for( int j=0; j<firstInputs.size(); j++ ) {
					if( !safe.contains(j) )
						continue;
					Hop h2 = otherInputs.get(j);
					if( !(h2 instanceof LiteralOp && HopRewriteUtils
						.isEqualValue((LiteralOp)firstInputs.get(j), (LiteralOp)h2)) )
						safe.remove(j);
				}
			}
			_fSafeLiterals.put(fkey, safe);
		}
	}
	
	@Override
	public int hashCode() {
		return Arrays.hashCode(new int[] {
			_fgraph.hashCode(),
			_fcand.hashCode(),
			_fcandUnary.hashCode(),
			_fcandSafeNNZ.hashCode(),
			_fSafeLiterals.hashCode()
		});
	}
	
	/**
	 * Two summaries are equal if they refer to the same call graph
	 * instance and hold equal candidate sets and safe-input maps.
	 */
	@Override
	public boolean equals(Object o) {
		if( this == o )
			return true;
		//was inverted: rejected all instances and cast everything else
		if( !(o instanceof FunctionCallSizeInfo) )
			return false;
		FunctionCallSizeInfo that = (FunctionCallSizeInfo)o;
		return _fgraph == that._fgraph
			&& _fcand.equals(that._fcand)
			&& _fcandUnary.equals(that._fcandUnary)
			&& _fcandSafeNNZ.equals(that._fcandSafeNNZ)
			&& _fSafeLiterals.equals(that._fSafeLiterals);
	}
	
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		sb.append("Valid functions for propagation: \n");
		for( String fkey : getValidFunctions() ) {
			sb.append("--");
			sb.append(fkey);
			sb.append(": ");
			sb.append(getFunctionCallCount(fkey));
			Set<Integer> nnzPositions = _fcandSafeNNZ.get(fkey);
			if( nnzPositions != null && !nnzPositions.isEmpty() ) {
				sb.append("\n----");
				sb.append(Arrays.toString(nnzPositions.toArray(new Integer[0])));
			}
			sb.append("\n");
		}
		Set<String> invalid = getInvalidFunctions();
		if( !invalid.isEmpty() ) {
			sb.append("Invalid functions for propagation: \n");
			for( String fkey : invalid ) {
				sb.append("--");
				sb.append(fkey);
				sb.append(": ");
				sb.append(getFunctionCallCount(fkey));
				sb.append("\n");
			}
		}
		if( !getDimsPreservingFunctions().isEmpty() ) {
			sb.append("Dimensions-preserving functions: \n");
			for( String fkey : getDimsPreservingFunctions() ) {
				sb.append("--");
				sb.append(fkey);
				sb.append(": ");
				sb.append(getFunctionCallCount(fkey));
				sb.append("\n");
			}
		}
		sb.append("Valid scalars for propagation: \n");
		appendPositions(sb, _fSafeLiterals);
		sb.append("Valid #non-zeros for propagation: \n");
		appendPositions(sb, _fcandSafeNNZ);
		return sb.toString();
	}
	
	/** Appends one line per function listing position:name of the first call's inputs. */
	private void appendPositions(StringBuilder sb, Map<String, Set<Integer>> positions) {
		for( Entry<String, Set<Integer>> e : positions.entrySet() ) {
			sb.append("--");
			sb.append(e.getKey());
			sb.append(": ");
			//entries are only created for functions with a non-empty call list
			List<Hop> inputs = _fgraph.getFunctionCalls(e.getKey()).get(0).getInput();
			for( Integer pos : e.getValue() ) {
				sb.append(pos);
				sb.append(":");
				sb.append(inputs.get(pos).getName());
				sb.append(" ");
			}
			sb.append("\n");
		}
	}
}