[SYSTEMDS-3566] Heuristic-based operator placement policy for GPU

This patch adds a few heuristic rules that move GPU operators to CP. Examples
include operators with sparse inputs and GPU operators sandwiched between CP
operators. The policy is implemented as a Lop rewrite.

Closes #1837
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 3086ada..62352bd 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -24,6 +24,7 @@
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.mapred.JobConf;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.CompilerConfig.ConfigType;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Compression.CompressConfig;
@@ -279,6 +280,12 @@
 			|| OptimizerUtils.ASYNC_CHECKPOINT_SPARK);
 	}
 
+	public static boolean isRuleBasedGPUPlacement() {
+		return (DMLScript.USE_ACCELERATOR &&
+			(getDMLConfig().getBooleanValue(DMLConfig.GPU_RULE_BASED_PLACEMENT)
+			|| OptimizerUtils.RULE_BASED_GPU_EXEC));
+	}
+
 	public static ILinearize.DagLinearization getLinearizationOrder() {
 		if (OptimizerUtils.COST_BASED_ORDERING)
 			return ILinearize.DagLinearization.AUTO;
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 46580c2..1b9cf2c 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -101,6 +101,7 @@
 	public static final String SYNCHRONIZE_GPU      = "sysds.gpu.sync.postProcess"; // boolean: whether to synchronize GPUs after every instruction
 	public static final String EAGER_CUDA_FREE      = "sysds.gpu.eager.cudaFree"; // boolean: whether to perform eager CUDA free on rmvar
 	public static final String GPU_EVICTION_POLICY  = "sysds.gpu.eviction.policy"; // string: can be lru, lfu, min_evict
+	public static final String GPU_RULE_BASED_PLACEMENT = "sysds.gpu.place.rulebased"; // boolean: apply rule-based operator placement for GPU
 	public static final String USE_LOCAL_SPARK_CONFIG = "sysds.local.spark"; // If set to true, it forces spark execution to a local spark context.
 	public static final String LOCAL_SPARK_NUM_THREADS = "sysds.local.spark.number.threads"; // the number of threads allowed to be used in the local spark configuration, default is * to enable use of all threads.
 	public static final String LINEAGECACHESPILL    = "sysds.lineage.cachespill"; // boolean: whether to spill cache entries to disk
@@ -193,6 +194,7 @@
 		_defaultVals.put(LOCAL_SPARK_NUM_THREADS, "*"); // * Means it allocates the number of available threads on the local host machine.
 		_defaultVals.put(SYNCHRONIZE_GPU,        "false" );
 		_defaultVals.put(EAGER_CUDA_FREE,        "false" );
+		_defaultVals.put(GPU_RULE_BASED_PLACEMENT, "false");
 		_defaultVals.put(FLOATING_POINT_PRECISION, "double" );
 		_defaultVals.put(USE_SSL_FEDERATED_COMMUNICATION, "false");
 		_defaultVals.put(DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, "10");
@@ -455,11 +457,11 @@
 			COMPRESSED_COCODE, COMPRESSED_TRANSPOSE, COMPRESSED_TRANSFORMENCODE, DAG_LINEARIZATION,
 			CODEGEN, CODEGEN_API, CODEGEN_COMPILER, CODEGEN_OPTIMIZER, CODEGEN_PLANCACHE, CODEGEN_LITERALS,
 			STATS_MAX_WRAP_LEN, LINEAGECACHESPILL, COMPILERASSISTED_RW, BUFFERPOOL_LIMIT, MEMORY_MANAGER,
-			PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, FLOATING_POINT_PRECISION,
-			GPU_EVICTION_POLICY, LOCAL_SPARK_NUM_THREADS, EVICTION_SHADOW_BUFFERSIZE, GPU_MEMORY_ALLOCATOR,
-			GPU_MEMORY_UTILIZATION_FACTOR, USE_SSL_FEDERATED_COMMUNICATION, DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT,
-			FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST,
-			ASYNC_SPARK_CHECKPOINT
+			PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, GPU_RULE_BASED_PLACEMENT,
+			FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY, LOCAL_SPARK_NUM_THREADS, EVICTION_SHADOW_BUFFERSIZE,
+			GPU_MEMORY_ALLOCATOR, GPU_MEMORY_UTILIZATION_FACTOR, USE_SSL_FEDERATED_COMMUNICATION,
+			DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY,
+			ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST, ASYNC_SPARK_CHECKPOINT
 		}; 
 		
 		StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 231435f..5020067 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -143,7 +143,8 @@
 			      input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput );
 		switch( _method ){
 			case TSMM: 
-				return false; // TODO: Disabling any fused transa optimization in 1.0 release. 
+				//return false; // TODO: Disabling any fused transa optimization in 1.0 release.
+				return true;
 			case MAPMM_CHAIN:
 				return false;
 			case PMM:
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 19cd080..1a7cc1c 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -305,6 +305,11 @@
 	 */
 	public static boolean COST_BASED_ORDERING = false;
 
+	/**
+	 * Rule-based operator placement policy for GPU.
+	 */
+	public static boolean RULE_BASED_GPU_EXEC = false;
+
 	//////////////////////
 	// Optimizer levels //
 	//////////////////////
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index 3d84348..5f32650 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -306,6 +306,10 @@
 		return inputs;
 	}
 
+	public Lop getInput(int index) {
+		return inputs.get(index);
+	}
+
 	/**
 	 * Method to get output of Lops
 	 * 
@@ -503,10 +507,19 @@
  		lps.setExecType(newExecType);
 	}
 
+
 	public boolean isExecSpark () {
 		return (lps.getExecType() == ExecType.SPARK);
 	}
 
+	public boolean isExecGPU () {
+		return (lps.getExecType() == ExecType.GPU);
+	}
+
+	public boolean isExecCP () {
+		return (lps.getExecType() == ExecType.CP);
+	}
+
 	public boolean getProducesIntermediateOutput() {
 		return lps.getProducesIntermediateOutput();
 	}
@@ -531,7 +544,19 @@
 	public OutputParameters getOutputParameters() {
 		return outParams;
 	}
-	
+
+	public long getNumRows() {
+		return getOutputParameters().getNumRows();
+	}
+
+	public long getNumCols() {
+		return getOutputParameters().getNumCols();
+	}
+
+	public long getNnz() {
+		return getOutputParameters().getNnz();
+	}
+
 	/**
 	 * Method to get aggregate type if applicable.
 	 * This method is overridden by the Lops with aggregate types (e.g. MapMult)
@@ -739,6 +764,25 @@
 	}
 
 	/**
+	 * Function that determines if all the outputs of a LOP are of GPU execution types
+	 *
+	 * @return true if the LOP has at least one output and all outputs are GPU; false otherwise
+	 */
+	public boolean isAllOutputsGPU() {
+		if (outputs.isEmpty())
+			return false;
+
+		boolean outGPU = true;
+		for (Lop out : getOutputs()) {
+			if (out.getExecType() != ExecType.GPU) {
+				outGPU = false;
+				break;
+			}
+		}
+		return outGPU;
+	}
+
+	/**
 	 * Method to prepare instruction operand with given parameters.
 	 * 
 	 * @param label instruction label
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 0457558..55b5905 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -44,6 +44,7 @@
 		_lopSBRuleSet.add(new RewriteAddBroadcastLop());
 		_lopSBRuleSet.add(new RewriteAddChkpointLop());
 		_lopSBRuleSet.add(new RewriteAddChkpointInLoop());
+		_lopSBRuleSet.add(new RewriteUpdateGPUPlacements());
 		// TODO: A rewrite pass to remove less effective chkpoints
 		// Last rewrite to reset Lop IDs in a depth-first manner
 		_lopSBRuleSet.add(new RewriteFixIDs());
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteUpdateGPUPlacements.java b/src/main/java/org/apache/sysds/lops/rewrite/RewriteUpdateGPUPlacements.java
new file mode 100644
index 0000000..6bdba60
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteUpdateGPUPlacements.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.lops.Data;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.MatMultCP;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.runtime.matrix.data.LibMatrixNative;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class RewriteUpdateGPUPlacements extends LopRewriteRule
+{
+	@Override
+	public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb)
+	{
+		// Return if rule-based GPU placement is disabled
+		if (!ConfigurationManager.isRuleBasedGPUPlacement())
+			return List.of(sb);
+
+		// Return if the lop list is unavailable or contains no GPU operators
+		ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
+		if (lops == null || lops.stream().noneMatch(Lop::isExecGPU))
+			return List.of(sb);
+
+		// Traverse the DAGs from their roots and apply the rules on the GPU operators
+		// TODO: Iterate multiple times to propagate the updates
+		List<Lop> roots = sb.getLops();
+		roots.forEach(this::rUpdateExecType);
+		roots.forEach(Lop::resetVisitStatus);
+
+		return List.of(sb);
+	}
+
+	@Override
+	public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+		return sbs;
+	}
+
+	private void updateExecTypeGPU2CP(Lop lop) {
+		// Return if not a GPU op
+		if (!lop.isExecGPU())
+			return;
+
+		// Rule1: Place only dense operators at GPU (no sparse inputs)
+		// An input is skipped by this check if its nnz is unknown (negative)
+		for (Lop in : lop.getInputs()) {
+			if (in.getNnz() >= 0
+				&& MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), in.getNumCols(), in.getNnz())) {
+				// Sparse input. Change to CP. This also avoids s2d and d2s conversions.
+				lop.setExecType(Types.ExecType.CP);
+				return;
+			}
+		}
+
+		// Rule2: Place compute-intensive MatMults at GPU regardless of the inputs' locations
+		if (lop instanceof MatMultCP) {
+			boolean memBound = LibMatrixNative.isMatMultMemoryBound((int) lop.getInput(0).getNumRows(),
+				(int) lop.getInput(0).getNumCols(), (int) lop.getInput(1).getNumCols());
+			if (!memBound) // Compute bound. Stays at GPU
+				return;
+		}
+
+		// Rule3: Location aware placement
+		// TODO: Propagate GPU execution types to DataOps (in hop level or lop level).
+		//  For now, skip this rule if the input is a DataOp.
+		if (lop.getInputs().size() == 2) { //binary operator
+			// Estimate sizes
+			long size1 = MatrixBlock.estimateSizeInMemory(lop.getInput(0).getNumRows(),
+				lop.getInput(0).getNumCols(), lop.getInput(0).getNnz());
+			long size2 = MatrixBlock.estimateSizeInMemory(lop.getInput(1).getNumRows(),
+				lop.getInput(1).getNumCols(), lop.getInput(1).getNnz());
+			// Move to CP if the larger input is a non-Data lop not at GPU and not all outputs are at GPU
+			if (size1 > size2 && !((lop.getInput(0)) instanceof Data)
+				&& !lop.getInput(0).isExecGPU() && !lop.isAllOutputsGPU())
+				lop.setExecType(Types.ExecType.CP);
+			if (size2 > size1 && !((lop.getInput(1)) instanceof Data)
+				&& !lop.getInput(1).isExecGPU() && !lop.isAllOutputsGPU())
+				lop.setExecType(Types.ExecType.CP);
+			// If same sized, move to CP if both inputs are non-Data lops not at GPU and not all outputs are at GPU
+			if (size1 == size2 &&!(lop.getInput(0) instanceof Data)
+				&& !(lop.getInput(1) instanceof Data) && !lop.getInput(0).isExecGPU()
+				&& !lop.getInput(1).isExecGPU() && !lop.isAllOutputsGPU())
+				lop.setExecType(Types.ExecType.CP);
+		}
+
+		// For unary, move to CP if the input is a non-Data lop not at GPU and not all outputs are at GPU
+		if (lop.getInputs().size() == 1)
+			if (!(lop.getInput(0) instanceof Data)
+				&& !lop.getInput(0).isExecGPU()
+				&& !lop.isAllOutputsGPU())
+				lop.setExecType(Types.ExecType.CP);
+
+		// For three or more inputs, move to CP if CP inputs outnumber GPU inputs (Data lops ignored) and not all outputs are at GPU
+		if (lop.getInputs().size() > 2) {
+			int numGPUInputs = 0;
+			int numCPInputs = 0;
+			for (Lop in : lop.getInputs()) {
+				if (!(in instanceof Data) && in.isExecGPU())
+					numGPUInputs++;
+				if (!(in instanceof Data) && in.isExecCP())
+					numCPInputs++;
+			}
+			if (numCPInputs > numGPUInputs && !lop.isAllOutputsGPU())
+				lop.setExecType(Types.ExecType.CP);
+		}
+	}
+
+	private void rUpdateExecType(Lop root) {
+		if (root.isVisited())
+			return;
+
+		for (Lop input : root.getInputs()) {
+			if (input instanceof Data)
+				continue;
+			rUpdateExecType(input);
+		}
+		updateExecTypeGPU2CP(root);
+		root.setVisited();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index f016235..dcc1707 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2568,6 +2568,10 @@
 		return estimateSizeInMemory(dc.getRows(), dc.getCols(), dc.getSparsity());
 	}
 
+	public static long estimateSizeInMemory(long nrows, long ncols, long nnz) {
+		return estimateSizeInMemory(nrows, ncols, OptimizerUtils.getSparsity(nrows, ncols, nnz));
+	}
+
 	public long estimateSizeDenseInMemory() {
 		return estimateSizeDenseInMemory(rlen, clen);
 	}