[SYSTEMDS-1780] Improved cost estimator for resource elasticity
Closes #2108.
diff --git a/scripts/perftest/resource/test_ops.dml b/scripts/perftest/resource/test_ops.dml
new file mode 100644
index 0000000..30d5a68
--- /dev/null
+++ b/scripts/perftest/resource/test_ops.dml
@@ -0,0 +1,13 @@
+
+X = read($X);
+Y = read($Y);
+Z = read($Z);
+
+A = X%*%Y;
+B = A + Z;
+C = B[1:1000,1:1000];
+
+print(nrow(A));
+print(nrow(B));
+print(nrow(C));
+
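+# Example invocation (illustrative, not part of the test harness; assumes the
+# inputs exist as SystemDS-readable files with metadata):
+#   systemds test_ops.dml -nvargs X=X.bin Y=Y.bin Z=Z.bin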
diff --git a/src/main/java/org/apache/sysds/resource/CloudUtils.java b/src/main/java/org/apache/sysds/resource/CloudUtils.java
index da09b80..c2ffe72 100644
--- a/src/main/java/org/apache/sysds/resource/CloudUtils.java
+++ b/src/main/java/org/apache/sysds/resource/CloudUtils.java
@@ -51,8 +51,8 @@
}
}
+ public static final String SPARK_VERSION = "3.3.0";
public static final double MINIMAL_EXECUTION_TIME = 120; // seconds; NOTE: set always equal or higher than DEFAULT_CLUSTER_LAUNCH_TIME
-
public static final double DEFAULT_CLUSTER_LAUNCH_TIME = 120; // seconds; NOTE: set always to at least 60 seconds
public static long GBtoBytes(double gb) {
diff --git a/src/main/java/org/apache/sysds/resource/ResourceCompiler.java b/src/main/java/org/apache/sysds/resource/ResourceCompiler.java
index ddf486c..4ddc381 100644
--- a/src/main/java/org/apache/sysds/resource/ResourceCompiler.java
+++ b/src/main/java/org/apache/sysds/resource/ResourceCompiler.java
@@ -55,11 +55,11 @@
public static final long DEFAULT_DRIVER_MEMORY = 512*1024*1024; // 0.5GB
public static final int DEFAULT_DRIVER_THREADS = 1; // 0.5GB
public static final long DEFAULT_EXECUTOR_MEMORY = 512*1024*1024; // 0.5GB
- public static final int DEFAULT_EXECUTOR_THREADS = 1; // 0.5GB
- public static final int DEFAULT_NUMBER_EXECUTORS = 1; // 0.5GB
+ public static final int DEFAULT_EXECUTOR_THREADS = 2; // avoids creating spark context
+ public static final int DEFAULT_NUMBER_EXECUTORS = 2; // avoids creating spark context
static {
// TODO: consider moving to the executable of the resource optimizer once implemented
- USE_LOCAL_SPARK_CONFIG = true;
+ // USE_LOCAL_SPARK_CONFIG = true; -> needs to be false to trigger evaluating the default parallelism
ConfigurationManager.getCompilerConfig().set(CompilerConfig.ConfigType.ALLOW_DYN_RECOMPILATION, false);
ConfigurationManager.getCompilerConfig().set(CompilerConfig.ConfigType.RESOURCE_OPTIMIZATION, true);
}
@@ -214,12 +214,26 @@
}
}
+ /**
+ * Sets resource configurations for the node executing the control program.
+ *
+ * @param nodeMemory memory in Bytes
+ * @param nodeNumCores number of CPU cores
+ */
public static void setDriverConfigurations(long nodeMemory, int nodeNumCores) {
- // TODO: think of reasonable factor for the JVM heap as prt of the node's memory
- InfrastructureAnalyzer.setLocalMaxMemory(nodeMemory);
+ // use 90% of the node's memory for the JVM heap -> rest needed for the OS
+ InfrastructureAnalyzer.setLocalMaxMemory((long) (0.9 * nodeMemory));
InfrastructureAnalyzer.setLocalPar(nodeNumCores);
}
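+
+	// Illustrative usage (values are assumptions, not part of this change):
+	//   setDriverConfigurations(CloudUtils.GBtoBytes(8), 4); // 8GB driver node, 4 cores
+	//   setExecutorConfigurations(2, CloudUtils.GBtoBytes(16), 8); // 2 executors with 16GB and 8 cores each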
+ /**
+ * Sets resource configurations for the cluster of nodes
+ * executing the Spark jobs.
+ *
+ * @param numExecutors number of nodes in cluster
+ * @param nodeMemory memory in Bytes per node
+ * @param nodeNumCores number of CPU cores per node
+ */
public static void setExecutorConfigurations(int numExecutors, long nodeMemory, int nodeNumCores) {
// TODO: think of reasonable factor for the JVM heap as prt of the node's memory
if (numExecutors > 0) {
@@ -235,6 +249,7 @@
sparkConf.set("spark.executor.memory", (nodeMemory/(1024*1024))+"m");
sparkConf.set("spark.executor.instances", Integer.toString(numExecutors));
sparkConf.set("spark.executor.cores", Integer.toString(nodeNumCores));
+ // not setting "spark.default.parallelism" on purpose -> allows re-initialization
// ------------------ Dynamic Configurations -------------------
SparkExecutionContext.initLocalSparkContext(sparkConf);
} else {
diff --git a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
new file mode 100644
index 0000000..6d46070
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
@@ -0,0 +1,917 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource.cost;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.*;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.*;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CMOperator;
+
+import static org.apache.sysds.resource.cost.IOCostUtils.IOMetrics;
+import static org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+
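+/**
+ * Utility class for estimating the execution time of CP (single-node)
+ * instructions. The estimates combine an approximate number of floating
+ * point operations per instruction with the compute and memory-bandwidth
+ * metrics of the driver node.
+ */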
+public class CPCostUtils {
+ private static final long DEFAULT_NFLOP_NOOP = 10;
+ private static final long DEFAULT_NFLOP_CP = 1;
+ private static final long DEFAULT_NFLOP_TEXT_IO = 350;
+ private static final long DEFAULT_INFERRED_DIM = 1000000;
+
+ public static double getVariableInstTime(VariableCPInstruction inst, VarStats input, VarStats output, IOMetrics metrics) {
+ long nflop;
+ switch (inst.getOpcode()) {
+ case "write":
+ String fmtStr = inst.getInput3().getLiteral().getStringValue();
+ Types.FileFormat fmt = Types.FileFormat.safeValueOf(fmtStr);
+ long xwrite = fmt.isTextFormat() ? DEFAULT_NFLOP_TEXT_IO : DEFAULT_NFLOP_CP;
+ nflop = input.getCellsWithSparsity() * xwrite;
+ break;
+ case "cast_as_matrix":
+ case "cast_as_frame":
+ nflop = input.getCells();
+ break;
+ case "rmfilevar": case "attachfiletovar": case "setfilename":
+ throw new RuntimeException("Undefined behaviour for instruction with opcode: " + inst.getOpcode());
+ default:
+ // negligibly low number of FLOP (independent of variables' dimensions)
+ return 0;
+ }
+ // assignOutputMemoryStats() needed only for casts
+ return getCPUTime(nflop, metrics, output, input);
+ }
+
+ public static double getDataGenCPInstTime(UnaryCPInstruction inst, VarStats output, IOMetrics metrics) {
+ long nflop;
+ String opcode = inst.getOpcode();
+ if( inst instanceof DataGenCPInstruction) {
+ if (opcode.equals("rand") || opcode.equals("frame")) {
+ DataGenCPInstruction rinst = (DataGenCPInstruction) inst;
+ if( rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0 )
+ nflop = 0; // empty matrix
+ else if( rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue() ) // allocate, array fill
+ nflop = 8 * output.getCells();
+ else { // full rand
+ if (rinst.getSparsity() == 1.0)
+ nflop = 32 * output.getCells() + 8 * output.getCells(); // DENSE gen (incl allocate)
+ else if (rinst.getSparsity() < MatrixBlock.SPARSITY_TURN_POINT)
+ nflop = 3 * output.getCellsWithSparsity() + 24 * output.getCellsWithSparsity(); //SPARSE gen (incl allocate)
+ else
+ nflop = 2 * output.getCells() + 8 * output.getCells(); // DENSE gen (incl allocate)
+ }
+ } else if (opcode.equals(DataGen.SEQ_OPCODE)) {
+ nflop = DEFAULT_NFLOP_CP * output.getCells();
+ } else {
+ // DataGen.SAMPLE_OPCODE, DataGen.TIME_OPCODE,
+ throw new RuntimeException("Undefined behaviour for instruction with opcode: " + inst.getOpcode());
+ }
+ }
+ else if( inst instanceof StringInitCPInstruction) {
+ nflop = DEFAULT_NFLOP_CP * output.getCells();
+ } else {
+ throw new IllegalArgumentException("Method has been called with invalid instruction: " + inst);
+ }
+ return getCPUTime(nflop, metrics, output);
+ }
+
+ public static double getUnaryInstTime(UnaryCPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics metrics) {
+ if (inst instanceof UaggOuterChainCPInstruction || inst instanceof DnnCPInstruction) {
+ throw new RuntimeException("Time estimation for CP instruction of class " + inst.getClass().getName() + "not supported yet");
+ }
+ // CPType = Unary/Builtin
+ CPType instructionType = inst.getCPInstructionType();
+ String opcode = inst.getOpcode();
+
+ boolean includeWeights = false;
+ if (inst instanceof MMTSJCPInstruction) {
+ MMTSJ.MMTSJType type = ((MMTSJCPInstruction) inst).getMMTSJType();
+ opcode += type.isLeft() ? "_left" : "_right";
+ } else if (inst instanceof ReorgCPInstruction && opcode.equals("rsort")) {
+ if (inst.input2 != null) includeWeights = true;
+ } else if (inst instanceof QuantileSortCPInstruction) {
+ if (inst.input2 != null) {
+ opcode += "_wts";
+ includeWeights = true;
+ }
+ } else if (inst instanceof CentralMomentCPInstruction) {
+ CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType();
+ opcode += "_" + opType.name().toLowerCase();
+ if (inst.input2 != null) {
+ includeWeights = true;
+ }
+ }
+ long nflop = getInstNFLOP(instructionType, opcode, output, input);
+ if (includeWeights)
+ return getCPUTime(nflop, metrics, output, input, weights);
+ return getCPUTime(nflop, metrics, output, input);
+ }
+
+ public static double getBinaryInstTime(BinaryCPInstruction inst, VarStats input1, VarStats input2, VarStats weights, VarStats output, IOMetrics metrics) {
+ // CPType = Binary/Builtin
+ CPType instructionType = inst.getCPInstructionType();
+ String opcode = inst.getOpcode();
+
+ boolean includeWeights = false;
+ if (inst instanceof CovarianceCPInstruction) { // cov
+ includeWeights = true;
+ } else if (inst instanceof QuantilePickCPInstruction) {
+ PickByCount.OperationTypes opType = ((QuantilePickCPInstruction) inst).getOperationType();
+ opcode += "_" + opType.name().toLowerCase();
+ } else if (inst instanceof AggregateBinaryCPInstruction) {
+ AggregateBinaryCPInstruction abinst = (AggregateBinaryCPInstruction) inst;
+ opcode += abinst.transposeLeft? "_tl": "";
+ opcode += abinst.transposeRight? "_tr": "";
+ }
+ long nflop = getInstNFLOP(instructionType, opcode, output, input1, input2);
+ if (includeWeights)
+ return getCPUTime(nflop, metrics, output, input1, input2, weights);
+ return getCPUTime(nflop, metrics, output, input1, input2);
+ }
+
+ public static double getComputationInstTime(ComputationCPInstruction inst, VarStats input1, VarStats input2, VarStats input3, VarStats input4, VarStats output, IOMetrics metrics) {
+ if (inst instanceof UnaryCPInstruction || inst instanceof BinaryCPInstruction) {
+ throw new RuntimeException("Instructions of type UnaryCPInstruction and BinaryCPInstruction are not handled by this method");
+ }
+ CPType instructionType = inst.getCPInstructionType();
+ String opcode = inst.getOpcode();
+
+ // CURRENTLY: 2 is the maximum number of needed input stats objects for NFLOP estimation
+ long nflop = getInstNFLOP(instructionType, opcode, output, input1, input2);
+ return getCPUTime(nflop, metrics, output, input1, input2, input3, input4);
+ }
+
+ public static double getBuiltinNaryInstTime(BuiltinNaryCPInstruction inst, VarStats[] inputs, VarStats output, IOMetrics metrics) {
+ CPType instructionType = inst.getCPInstructionType();
+ String opcode = inst.getOpcode();
+ long nflop;
+ if (inputs == null) {
+ nflop = getInstNFLOP(instructionType, opcode, output);
+ return getCPUTime(nflop, metrics, output);
+ }
+ nflop = getInstNFLOP(instructionType, opcode, output, inputs);
+ return getCPUTime(nflop, metrics, output, inputs);
+ }
+
+ public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinCPInstruction inst, VarStats input, VarStats output, IOMetrics metrics) {
+ CPType instructionType = inst.getCPInstructionType();
+ String opcode = inst.getOpcode();
+ if (opcode.equals("rmempty")) {
+ String margin = inst.getParameterMap().get("margin");
+ opcode += "_" + margin;
+ } else if (opcode.equals("groupedagg")) {
+ CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType();
+ opcode += "_" + opType.name().toLowerCase();
+ }
+ long nflop = getInstNFLOP(instructionType, opcode, output, input);
+ return getCPUTime(nflop, metrics, output, input);
+ }
+
+ public static double getMultiReturnBuiltinInstTime(MultiReturnBuiltinCPInstruction inst, VarStats input, VarStats[] outputs, IOMetrics metrics) {
+ CPType instructionType = inst.getCPInstructionType();
+ String opcode = inst.getOpcode();
+ long nflop = getInstNFLOP(instructionType, opcode, outputs[0], input);
+ double time = getCPUTime(nflop, metrics, outputs[0], input);
+ for (int i = 1; i < outputs.length; i++) {
+ time += IOCostUtils.getMemWriteTime(outputs[i], metrics);
+ }
+ return time;
+ }
+
+ // HELPERS
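+ /**
+ * Infers any missing dimensions of the output statistics from the inputs
+ * and assigns the estimated in-memory size to the output. For
+ * {@code MultiReturnBuiltin} instructions the sizes are assigned to all
+ * result variables, which are passed in place of the inputs.
+ */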
+ public static void assignOutputMemoryStats(CPInstruction inst, VarStats output, VarStats...inputs) {
+ CPType instType = inst.getCPInstructionType();
+ String opcode = inst.getOpcode();
+
+ if (inst instanceof MultiReturnBuiltinCPInstruction) {
+ boolean inferred = false;
+ for (VarStats current : inputs) {
+ if (!inferred && current.getCells() < 0) {
+ inferStats(instType, opcode, output, inputs);
+ inferred = true;
+ }
+ if (current.getCells() < 0) {
+ throw new RuntimeException("Operation of type MultiReturnBuiltin with opcode '" + opcode + "' has incomplete formula for inferring dimensions");
+ }
+ current.allocatedMemory = OptimizerUtils.estimateSizeExactSparsity(current.characteristics);
+ }
+ return;
+ } else if (output.getCells() < 0) {
+ inferStats(instType, opcode, output, inputs);
+ }
+ output.allocatedMemory = output.isScalar()? 1 : OptimizerUtils.estimateSizeExactSparsity(output.characteristics);
+ }
+
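+ /**
+ * Infers missing dimensions (and non-zero counts) of the output statistics
+ * based on the instruction type, its opcode and the available input statistics.
+ *
+ * @param instType CP instruction type
+ * @param opcode instruction opcode
+ * @param output output statistics with potentially missing dimensions
+ * @param inputs input statistics used for the inference
+ */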
+ public static void inferStats(CPType instType, String opcode, VarStats output, VarStats...inputs) {
+ switch (instType) {
+ case Unary:
+ case Builtin:
+ copyMissingDim(output, inputs[0]);
+ break;
+ case AggregateUnary:
+ if (opcode.startsWith("uar")) {
+ copyMissingDim(output, inputs[0].getM(), 1);
+ } else if (opcode.startsWith("uac")) {
+ copyMissingDim(output, 1, inputs[0].getN());
+ } else {
+ copyMissingDim(output, 1, 1);
+ }
+ break;
+ case MatrixIndexing:
+ if (opcode.equals("rightIndex")) {
+ long rowLower = (inputs[2].varName.matches("\\d+") ? Long.parseLong(inputs[2].varName) : -1);
+ long rowUpper = (inputs[3].varName.matches("\\d+") ? Long.parseLong(inputs[3].varName) : -1);
+ long colLower = (inputs[4].varName.matches("\\d+") ? Long.parseLong(inputs[4].varName) : -1);
+ long colUpper = (inputs[5].varName.matches("\\d+") ? Long.parseLong(inputs[5].varName) : -1);
+
+ long rowRange;
+ {
+ if (rowLower > 0 && rowUpper > 0) rowRange = rowUpper - rowLower + 1;
+ else if (inputs[2].varName.equals(inputs[3].varName)) rowRange = 1;
+ else
+ rowRange = inputs[0].getM() > 0 ? inputs[0].getM() : DEFAULT_INFERRED_DIM;
+ }
+ long colRange;
+ {
+ if (colLower > 0 && colUpper > 0) colRange = colUpper - colLower + 1;
+ else if (inputs[4].varName.equals(inputs[5].varName)) colRange = 1;
+ else
+ colRange = inputs[0].getN() > 0 ? inputs[0].getN() : DEFAULT_INFERRED_DIM;
+ }
+ copyMissingDim(output, rowRange, colRange);
+ } else { // leftIndex
+ copyMissingDim(output, inputs[0]);
+ }
+ break;
+ case Reorg:
+ switch (opcode) {
+ case "r'":
+ copyMissingDim(output, inputs[0].getN(), inputs[0].getM());
+ break;
+ case "rev":
+ copyMissingDim(output, inputs[0]);
+ break;
+ case "rdiag":
+ if (inputs[0].getN() == 1) // diagV2M
+ copyMissingDim(output, inputs[0].getM(), inputs[0].getM());
+ else // diagM2V
+ copyMissingDim(output, inputs[0].getM(), 1);
+ break;
+ case "rsort":
+ boolean ixRet = Boolean.parseBoolean(inputs[1].varName);
+ if (ixRet)
+ copyMissingDim(output, inputs[0].getM(), 1);
+ else
+ copyMissingDim(output, inputs[0]);
+ break;
+ }
+ break;
+ case Binary:
+ // handle case of matrix-scalar op. with the matrix being the second operand
+ VarStats origin = inputs[0].isScalar()? inputs[1] : inputs[0];
+ copyMissingDim(output, origin);
+ break;
+ case AggregateBinary:
+ boolean transposeLeft = false;
+ boolean transposeRight = false;
+ if (inputs.length == 4) {
+ transposeLeft = inputs[2] != null && Boolean.parseBoolean(inputs[2].varName);
+ transposeRight = inputs[3] != null && Boolean.parseBoolean(inputs[3].varName);
+ }
+ if (transposeLeft && transposeRight)
+ copyMissingDim(output, inputs[0].getM(), inputs[1].getM());
+ else if (transposeLeft)
+ copyMissingDim(output, inputs[0].getM(), inputs[1].getN());
+ else if (transposeRight)
+ copyMissingDim(output, inputs[0].getN(), inputs[1].getN());
+ else
+ copyMissingDim(output, inputs[0].getN(), inputs[1].getM());
+ break;
+ case ParameterizedBuiltin:
+ if (opcode.equals("rmempty") || opcode.equals("replace")) {
+ copyMissingDim(output, inputs[0]);
+ } else if (opcode.equals("uppertri") || opcode.equals("lowertri")) {
+ copyMissingDim(output, inputs[0].getM(), inputs[0].getM());
+ }
+ break;
+ case Rand:
+ // inferring missing output dimensions is handled exceptionally here
+ if (output.getCells() < 0) {
+ long nrows = (inputs[0].varName.matches("\\d+") ? Long.parseLong(inputs[0].varName) : -1);
+ long ncols = (inputs[1].varName.matches("\\d+") ? Long.parseLong(inputs[1].varName) : -1);
+ copyMissingDim(output, nrows, ncols);
+ }
+ break;
+ case Ctable:
+ long m = (inputs[2].varName.matches("\\d+") ? Long.parseLong(inputs[2].varName) : -1);
+ long n = (inputs[3].varName.matches("\\d+") ? Long.parseLong(inputs[3].varName) : -1);
+ if (inputs[1].isScalar()) {// Histogram
+ if (m < 0) m = inputs[0].getM();
+ if (n < 0) n = 1;
+ copyMissingDim(output, m, n);
+ } else { // transform (including "ctableexpand")
+ if (m < 0) m = inputs[0].getM();
+ if (n < 0) n = inputs[1].getCells(); // NOTE: very generous assumption, it could be revised;
+ copyMissingDim(output, m, n);
+ }
+ break;
+ case MultiReturnBuiltin:
+ // special case: output and inputs stats arguments are swapped: always single input with multiple outputs
+ VarStats firstStats = inputs[0];
+ VarStats secondStats = inputs[1];
+ switch (opcode) {
+ case "qr":
+ copyMissingDim(firstStats, output.getM(), output.getM()); // Q
+ copyMissingDim(secondStats, output.getM(), output.getN()); // R
+ break;
+ case "lu":
+ copyMissingDim(firstStats, output.getN(), output.getN()); // L
+ copyMissingDim(secondStats, output.getN(), output.getN()); // U
+ break;
+ case "eigen":
+ copyMissingDim(firstStats, output.getN(), 1); // values
+ copyMissingDim(secondStats, output.getN(), output.getN()); // vectors
+ break;
+ // not all opcodes supported yet
+ }
+ break;
+ default:
+ throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions");
+ }
+ if (output.getCells() < 0) {
+ throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions");
+ }
+ if (output.getNNZ() < 0) {
+ output.characteristics.setNonZeros(output.getCells());
+ }
+ }
+
+ private static void copyMissingDim(VarStats target, long originRows, long originCols) {
+ if (target.getM() < 0)
+ target.characteristics.setRows(originRows);
+ if (target.getN() < 0)
+ target.characteristics.setCols(originCols);
+ }
+
+ private static void copyMissingDim(VarStats target, VarStats origin) {
+ if (target.getM() < 0)
+ target.characteristics.setRows(origin.getM());
+ if (target.getN() < 0)
+ target.characteristics.setCols(origin.getN());
+ }
+
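+ /**
+ * Combines the estimated read, compute and write times of a CP operation
+ * as {@code max(memory read time, compute time) + memory write time}.
+ *
+ * @param nflop number of floating point operations of the computation
+ * @param driverMetrics I/O and compute metrics of the driver node
+ * @param output output statistics ({@code null} if no output is written)
+ * @param inputs statistics of the inputs read from memory
+ * @return estimated time in seconds
+ */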
+ public static double getCPUTime(long nflop, IOCostUtils.IOMetrics driverMetrics, VarStats output, VarStats...inputs) {
+ double memScanTime = 0;
+ for (VarStats input: inputs) {
+ if (input == null) continue;
+ memScanTime += IOCostUtils.getMemReadTime(input, driverMetrics);
+ }
+ double cpuComputationTime = (double) nflop / driverMetrics.cpuFLOPS;
+ double memWriteTime = output != null? IOCostUtils.getMemWriteTime(output, driverMetrics) : 0;
+ return Math.max(memScanTime, cpuComputationTime) + memWriteTime;
+ }
+
+ /**
+ *
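+ * Estimates the number of floating point operations (NFLOP) of a single
+ * CP instruction. For example (illustrative), a binary {@code "+"} over two
+ * matrices is costed as
+ * {@code inputs[0].getCellsWithSparsity() + inputs[1].getCellsWithSparsity()}.
+ *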
+ * @param instructionType instruction type
+ * @param opcode instruction opcode, potentially with suffix to mark an extra op. characteristic
+ * @param output output variable statistics; may be null if not needed for the estimation
+ * @param inputs input variables' statistics; each object may be omitted if not needed for the estimation
+ * @return estimated number of floating point operations
+ */
+ public static long getInstNFLOP(
+ CPType instructionType,
+ String opcode,
+ VarStats output,
+ VarStats...inputs
+ ) {
+ opcode = opcode.toLowerCase(); // enforce lowercase for convenience
+ long m;
+ double costs = 0;
+ switch (instructionType) {
+ // types corresponding to UnaryCPInstruction
+ case Unary:
+ case Builtin: // log and log_nz only
+ if (output == null || inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for Unary/Builtin operations are passed initialized");
+ double sparsity = inputs[0].getSparsity();
+ switch (opcode) {
+ case "!":
+ case "isna":
+ case "isnan":
+ case "isinf":
+ case "ceil":
+ case "floor":
+ costs = 1;
+ break;
+ case "abs":
+ case "round":
+ case "sign":
+ costs = 1 * sparsity;
+ break;
+ case "sprop":
+ case "sqrt":
+ costs = 2 * sparsity;
+ break;
+ case "exp":
+ costs = 18 * sparsity;
+ break;
+ case "sigmoid":
+ costs = 21 * sparsity;
+ break;
+ case "log":
+ costs = 32;
+ break;
+ case "log_nz":
+ case "plogp":
+ costs = 32 * sparsity;
+ break;
+ case "print":
+ case "assert":
+ costs = 1;
+ break;
+ case "sin":
+ costs = 18 * sparsity;
+ break;
+ case "cos":
+ costs = 22 * inputs[0].getSparsity();
+ break;
+ case "tan":
+ costs = 42 * inputs[0].getSparsity();
+ break;
+ case "asin":
+ case "sinh":
+ costs = 93;
+ break;
+ case "acos":
+ case "cosh":
+ costs = 103;
+ break;
+ case "atan":
+ case "tanh":
+ costs = 40;
+ break;
+ case "ucumk+":
+ case "ucummin":
+ case "ucummax":
+ case "ucum*":
+ costs = 1 * sparsity;
+ break;
+ case "ucumk+*":
+ costs = 2 * sparsity;
+ break;
+ case "stop":
+ costs = 0;
+ break;
+ case "typeof":
+ costs = 1;
+ break;
+ case "inverse":
+ costs = (4.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity();
+ break;
+ case "cholesky":
+ costs = (1.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity();
+ break;
+ case "detectschema":
+ case "colnames":
+ throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ // at the point of implementation no further supported operations
+ throw new DMLRuntimeException("Unary operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ return (long) (costs * output.getCells());
+ case AggregateUnary:
+ if (output == null || inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for AggregateUnary operations are passed initialized");
+ switch (opcode) {
+ case "nrow":
+ case "ncol":
+ case "length":
+ case "exists":
+ case "lineage":
+ return DEFAULT_NFLOP_NOOP;
+ case "uak+":
+ case "uark+":
+ case "uack+":
+ costs = 4;
+ break;
+ case "uasqk+":
+ case "uarsqk+":
+ case "uacsqk+":
+ costs = 5;
+ break;
+ case "uamean":
+ case "uarmean":
+ case "uacmean":
+ costs = 7;
+ break;
+ case "uavar":
+ case "uarvar":
+ case "uacvar":
+ costs = 14;
+ break;
+ case "uamax":
+ case "uarmax":
+ case "uarimax":
+ case "uacmax":
+ case "uamin":
+ case "uarmin":
+ case "uarimin":
+ case "uacmin":
+ costs = 1;
+ break;
+ case "ua+":
+ case "uar+":
+ case "uac+":
+ case "ua*":
+ case "uar*":
+ case "uac*":
+ costs = 1 * output.getSparsity();
+ break;
+ // count distinct operations
+ case "uacd":
+ case "uacdr":
+ case "uacdc":
+ case "unique":
+ case "uniquer":
+ case "uniquec":
+ costs = 1 * output.getSparsity();
+ break;
+ case "uacdap":
+ case "uacdapr":
+ case "uacdapc":
+ costs = 0.5 * output.getSparsity(); // do not iterate through all the cells
+ break;
+ // aggregation over the diagonal of a square matrix
+ case "uatrace":
+ case "uaktrace":
+ return inputs[0].getM();
+ default:
+ // at the point of implementation no further supported operations
+ throw new DMLRuntimeException("AggregateUnary operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ // scale
+ if (opcode.startsWith("uar")) {
+ costs *= inputs[0].getM();
+ } else if (opcode.startsWith("uac")) {
+ costs *= inputs[0].getN();
+ } else {
+ costs *= inputs[0].getCells();
+ }
+ return (long) (costs * output.getCells());
+ case MMTSJ:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for MMTSJ operations are passed initialized");
+ // reduce by factor of 4: matrix multiplication better than average FLOP count
+ // + multiply only upper triangular
+ if (opcode.equals("tsmm_left")) {
+ costs = inputs[0].getN() * (inputs[0].getSparsity() / 2);
+ } else { // tsmm/tsmm_right
+ costs = inputs[0].getM() * (inputs[0].getSparsity() / 2);
+ }
+ return (long) (costs * inputs[0].getCellsWithSparsity());
+ case Reorg:
+ case Reshape:
+ if (output == null)
+ throw new RuntimeException("Not all required arguments for Reorg/Reshape operations are passed initialized");
+ if (opcode.equals("rsort"))
+ return (long) (output.getCellsWithSparsity() * (Math.log(output.getM()) / Math.log(2))); // merge sort columns (n*m*log2(m))
+ return output.getCellsWithSparsity();
+ case MatrixIndexing:
+ if (output == null)
+ throw new RuntimeException("Not all required arguments for Indexing operations are passed initialized");
+ return output.getCellsWithSparsity();
+ case MMChain:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for MMChain operations are passed initialized");
+ // reduction by factor of 2 because matrix multiplication is better than the average FLOP count
+ // (MMChain is essentially two matrix-vector multiplications)
+ return (2 + 2) * inputs[0].getCellsWithSparsity() / 2;
+ case QSort:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for QSort operations are passed initialized");
+ // mergesort since comparator used
+ m = inputs[0].getM();
+ if (opcode.equals("qsort"))
+ costs = m + m;
+ else // == "qsort_wts" (with weights)
+ costs = m * inputs[0].getSparsity();
+ return (long) (costs + m * (int) (Math.log(m) / Math.log(2)) + m);
+ case CentralMoment:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for CentralMoment operations are passed initialized");
+ switch (opcode) {
+ case "cm_sum":
+ throw new RuntimeException("Undefined behaviour for CentralMoment operation of type sum");
+ case "cm_min":
+ case "cm_max":
+ case "cm_count":
+ costs = 2;
+ break;
+ case "cm_mean":
+ costs = 9;
+ break;
+ case "cm_variance":
+ case "cm_cm2":
+ costs = 17;
+ break;
+ case "cm_cm3":
+ costs = 32;
+ break;
+ case "cm_cm4":
+ costs = 52;
+ break;
+ case "cm_invalid":
+ // type INVALID used when unknown dimensions
+ throw new RuntimeException("CentralMoment operation of type INVALID is not supported");
+ default:
+ // at the point of implementation no further supported operations
+ throw new DMLRuntimeException("CentralMoment operation with type (<opcode>_<type>) '" + opcode + "' is not supported by SystemDS");
+ }
+ return (long) costs * inputs[0].getCellsWithSparsity();
+ case UaggOuterChain:
+ case Dnn:
+ throw new RuntimeException("CP operation type'" + instructionType + "' is not supported yet");
+ // types corresponding to BinaryCPInstruction
+ case Binary:
+ if (opcode.equals("+") || opcode.equals("-")) {
+ if (inputs.length < 2)
+ throw new RuntimeException("Not all required arguments for Binary operations +/- are passed initialized");
+ return inputs[0].getCellsWithSparsity() + inputs[1].getCellsWithSparsity();
+ } else if (opcode.equals("solve")) {
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for Binary operation 'solve' are passed initialized");
+ return inputs[0].getCells() * inputs[0].getN();
+ }
+ if (output == null)
+ throw new RuntimeException("Not all required arguments for Binary operations are passed initialized");
+ switch (opcode) {
+ case "*":
+ case "^2":
+ case "*2":
+ case "max":
+ case "min":
+ case "-nz":
+ case "==":
+ case "!=":
+ case "<":
+ case ">":
+ case "<=":
+ case ">=":
+ case "&&":
+ case "||":
+ case "xor":
+ case "bitwand":
+ case "bitwor":
+ case "bitwxor":
+ case "bitwshiftl":
+ case "bitwshiftr":
+ costs = 1;
+ break;
+ case "%/%":
+ costs = 6;
+ break;
+ case "%%":
+ costs = 8;
+ break;
+ case "/":
+ costs = 22;
+ break;
+ case "log":
+ case "log_nz":
+ costs = 32;
+ break;
+ case "^":
+ costs = 16;
+ break;
+ case "1-*":
+ costs = 2;
+ break;
+ case "dropinvalidtype":
+ case "dropinvalidlength":
+ case "freplicate":
+ case "valueswap":
+ case "applyschema":
+ throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ // at the point of implementation no further supported operations
+ throw new DMLRuntimeException("Binary operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ return (long) (costs * output.getCells());
+ case AggregateBinary:
+ if (output == null || inputs.length < 2)
+ throw new RuntimeException("Not all required arguments for AggregateBinary operations are passed initialized");
+ // costs represents the cost for matrix transpose
+ if (opcode.contains("_tl")) costs = inputs[0].getCellsWithSparsity();
+ if (opcode.contains("_tr")) costs = inputs[1].getCellsWithSparsity();
+ // else ba+*/pmm (or any of cpmm/rmm/mapmm from the Spark instructions)
+ // reduce by factor of 2: matrix multiplication better than average FLOP count: 2*m*n*p -> m*n*p
+ return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells() + (long) costs;
+ case Append:
+ if (inputs.length < 2)
+ throw new RuntimeException("Not all required arguments for Append operation is passed initialized");
+ return inputs[0].getCellsWithSparsity() * inputs[1].getCellsWithSparsity();
+ case Covariance:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for Covariance operation is passed initialized");
+ return (long) (23 * inputs[0].getM() * inputs[0].getSparsity());
+ case QPick:
+ switch (opcode) {
+ case "qpick_iqm":
+ m = inputs[0].getM();
+ return (long) (2 * m + //sum of weights
+ 5 * 0.25d * m + //scan to lower quantile
+ 8 * 0.5 * m); //scan from lower to upper quantile
+ case "qpick_median":
+ case "qpick_valuepick":
+ case "qpick_rangepick":
+ throw new RuntimeException("QuantilePickCPInstruction of operation type different from IQM is not supported yet");
+ default:
+ throw new DMLRuntimeException("QPick operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ // types corresponding to others CPInstruction(s)
+ case Ternary:
+ if (output == null)
+ throw new RuntimeException("Not all required arguments for Ternary operation is passed initialized");
+ switch (opcode) {
+ case "+*":
+ case "-*":
+ case "ifelse":
+ return 2 * output.getCells();
+ case "_map":
+ throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ throw new DMLRuntimeException("Ternary operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ case AggregateTernary:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for AggregateTernary operation is passed initialized");
+ if (opcode.equals("tak+*") || opcode.equals("tack+*"))
+ return 6 * inputs[0].getCellsWithSparsity();
+ throw new DMLRuntimeException("AggregateTernary operation with opcode '" + opcode + "' is not supported by SystemDS");
+ case Quaternary:
+ //TODO pattern specific and all inputs required
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for Quaternary operation is passed initialized");
+ if (opcode.equals("wsloss") || opcode.equals("wdivmm") || opcode.equals("wcemm")) {
+ // 4 matrices used
+ return 4 * inputs[0].getCells();
+ } else if (opcode.equals("wsigmoid") || opcode.equals("wumm")) {
+ // 3 matrices used
+ return 3 * inputs[0].getCells();
+ }
+ throw new DMLRuntimeException("Quaternary operation with opcode '" + opcode + "' is not supported by SystemDS");
+ case BuiltinNary:
+ if (output == null)
+ throw new RuntimeException("Not all required arguments for BuiltinNary operation is passed initialized");
+ switch (opcode) {
+ case "cbind":
+ case "rbind":
+ return output.getCellsWithSparsity();
+ case "nmin":
+ case "nmax":
+ case "n+":
+ return inputs.length * output.getCellsWithSparsity();
+ case "printf":
+ case "list":
+ return output.getN();
+ case "eval":
+ throw new RuntimeException("EvalNaryCPInstruction is not supported yet");
+ default:
+ throw new DMLRuntimeException("BuiltinNary operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ case Ctable:
+ if (output == null)
+ throw new RuntimeException("Not all required arguments for Ctable operation is passed initialized");
+ if (opcode.startsWith("ctable")) {
+ // potentially high inaccuracy due to the unknown output column size
+ // and the inferred bound on the number of elements, which could lead to high underestimation
+ return 3 * output.getCellsWithSparsity();
+ }
+ throw new DMLRuntimeException("Ctable operation with opcode '" + opcode + "' is not supported by SystemDS");
+ case PMMJ:
+ // currently this would never be reached since the pmm instruction uses AggregateBinary op. type
+ if (output == null || inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for PMMJ operation is passed initialized");
+ if (opcode.equals("pmm")) {
+ return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells();
+ }
+ throw new DMLRuntimeException("PMMJ operation with opcode '" + opcode + "' is not supported by SystemDS");
+ case ParameterizedBuiltin:
+ // no argument validation here since the logic is not fully defined for this operation
+ m = inputs[0].getM();
+ switch (opcode) {
+ case "contains":
+ case "replace":
+ case "tostring":
+ return inputs[0].getCells();
+ case "nvlist":
+ case "cdf":
+ case "invcdf":
+ case "lowertri":
+ case "uppertri":
+ case "rexpand":
+ return output.getCells();
+ case "rmempty_rows":
+ return (long) (inputs[0].getM() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2)
+ + output.getCells();
+ case "rmempty_cols":
+ return (long) (inputs[0].getN() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2)
+ + output.getCells();
+ // opcode: "groupedagg"
+ case "groupedagg_count":
+ case "groupedagg_min":
+ case "groupedagg_max":
+ return 2 * m + m;
+ case "groupedagg_sum":
+ return 2 * m + 4 * m;
+ case "groupedagg_mean":
+ return 2 * m + 8 * m;
+ case "groupedagg_cm2":
+ return 2 * m + 16 * m;
+ case "groupedagg_cm3":
+ return 2 * m + 31 * m;
+ case "groupedagg_cm4":
+ return 2 * m + 51 * m;
+ case "groupedagg_variance":
+ return 2 * m + 16 * m;
+ case "groupedagg_invalid":
+ // type INVALID used when unknown dimensions
+ throw new RuntimeException("ParameterizedBuiltin operation with opcode 'groupedagg' of type INVALID is not supported");
+ case "tokenize":
+ case "transformapply":
+ case "transformdecode":
+ case "transformcolmap":
+ case "transformmeta":
+ case "autodiff":
+ case "paramserv":
+ throw new RuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ throw new DMLRuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ case MultiReturnBuiltin:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for MultiReturnBuiltin operation is passed initialized");
+ switch (opcode) {
+ case "qr":
+ costs = 2;
+ break;
+ case "lu":
+ costs = 16;
+ break;
+ case "eigen":
+ case "svd":
+ costs = 32;
+ break;
+ case "fft":
+ case "fft_linearized":
+ throw new RuntimeException("MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ throw new DMLRuntimeException(" MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ return (long) (costs * inputs[0].getCells() * inputs[0].getN());
+ case Prefetch:
+ case EvictLineageCache:
+ case Broadcast:
+ case Local:
+ case FCall:
+ case NoOp:
+ // not directly related to computation
+ return 0;
+ case Variable:
+ case Rand:
+ case StringInit:
+ throw new RuntimeException(instructionType + " instructions are not handled by this method");
+ case MultiReturnParameterizedBuiltin: // opcodes: transformencode
+ case MultiReturnComplexMatrixBuiltin: // opcodes: ifft, ifft_linearized, stft, rcm
+ case Compression: // opcode: compress
+ case DeCompression: // opcode: decompress
+ throw new RuntimeException("CP operation type'" + instructionType + "' is not supported yet");
+ case TrigRemote:
+ case Partition:
+ case SpoofFused:
+ case Sql:
+ throw new RuntimeException("CP operation type'" + instructionType + "' is not planned for support");
+ default:
+ // no further supported CP types
+ throw new DMLRuntimeException("CP operation type'" + instructionType + "' is not supported by SystemDS");
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
index 0c056e2..0f7fc5d 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
@@ -19,17 +19,14 @@
package org.apache.sysds.resource.cost;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.common.Types;
-import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.LeftIndex;
-import org.apache.sysds.lops.Lop;
-import org.apache.sysds.lops.MMTSJ;
-import org.apache.sysds.lops.RightIndex;
+import org.apache.sysds.lops.MapMult;
import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.resource.CloudInstance;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.*;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -37,101 +34,101 @@
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.*;
import org.apache.sysds.runtime.instructions.spark.*;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
-import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
-import static org.apache.sysds.common.Types.FileFormat.BINARY;
-import static org.apache.sysds.common.Types.FileFormat.TEXT;
import static org.apache.sysds.lops.Data.PREAD_PREFIX;
-import static org.apache.sysds.resource.cost.IOCostUtils.HDFS_SOURCE_IDENTIFIER;
-import static org.apache.sysds.resource.cost.IOCostUtils.S3_SOURCE_IDENTIFIER;
+import static org.apache.sysds.lops.DataGen.*;
+import static org.apache.sysds.resource.cost.IOCostUtils.*;
+import static org.apache.sysds.resource.cost.SparkCostUtils.getRandInstTime;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
+import java.util.*;
/**
* Class for estimating the execution time of a program.
* For estimating the time for new set of resources,
* a new instance of CostEstimator should be created.
- * TODO: consider reusing of some parts of the computation
- * for small changes in the resources
*/
public class CostEstimator
{
- protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());
-
- private static final int DEFAULT_NUMITER = 15;
-
- //time-conversion
- private static final long DEFAULT_FLOPS = 2L * 1024 * 1024 * 1024; //2GFLOPS
- //private static final long UNKNOWN_TIME = -1;
-
- //floating point operations
- private static final double DEFAULT_NFLOP_NOOP = 10;
- private static final double DEFAULT_NFLOP_UNKNOWN = 1;
- private static final double DEFAULT_NFLOP_CP = 1;
- private static final double DEFAULT_NFLOP_TEXT_IO = 350;
-
- protected static long CP_FLOPS = DEFAULT_FLOPS;
- protected static long SP_FLOPS = DEFAULT_FLOPS;
- protected static final VarStats _unknownStats = new VarStats(new MatrixCharacteristics(-1,-1,-1,-1));
-
+ private static final long MIN_MEMORY_TO_TRACK = 1024 * 1024; // 1MB
+ private static final int DEFAULT_NUM_ITER = 15;
// Non-static members
- private SparkExecutionContext.MemoryManagerParRDDs _parRDDs;
- @SuppressWarnings("unused")
- private double[] cpCost; // (compute cost, I/O cost) for CP instructions
- @SuppressWarnings("unused")
- private double[] spCost; // (compute cost, I/O cost) for Spark instructions
-
+ private final Program _program;
+ private final IOCostUtils.IOMetrics driverMetrics;
+ private final IOCostUtils.IOMetrics executorMetrics;
// declare here the hashmaps
- protected HashMap<String, VarStats> _stats;
- // protected HashMap<String, RDDStats> _sparkStats;
- // protected HashMap<Integer, LinkedList<String>> _transformations;
- protected HashSet<String> _functions;
- private final long localMemory;
- private long usedMememory;
+ private final HashMap<String, VarStats> _stats;
+ private final HashSet<String> _functions;
+ private final long localMemoryLimit; // refers to the driver's JVM memory
+ private long freeLocalMemory;
/**
* Entry point for estimating the execution time of a program.
* @param program compiled runtime program
+ * @param driverNode cloud instance characterizing the driver node
+ * @param executorNode cloud instance characterizing the executor nodes ({@code null} for single-node execution)
* @return estimated time for execution of the program
* given the resources set in {@link SparkExecutionContext}
* @throws CostEstimationException in case of errors
*/
- public static double estimateExecutionTime(Program program) throws CostEstimationException {
- CostEstimator estimator = new CostEstimator();
- double costs = estimator.getTimeEstimate(program);
- return costs;
+ public static double estimateExecutionTime(Program program, CloudInstance driverNode, CloudInstance executorNode) throws CostEstimationException {
+ CostEstimator estimator = new CostEstimator(program, driverNode, executorNode);
+ return estimator.getTimeEstimate();
}
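+
+	// Illustrative usage (obtaining the CloudInstance objects is an assumption):
+	//   double expectedSeconds = CostEstimator.estimateExecutionTime(program, driverInstance, executorInstance);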
- private CostEstimator() {
+ public CostEstimator(Program program, CloudInstance driverNode, CloudInstance executorNode) {
+ _program = program;
+ driverMetrics = new IOCostUtils.IOMetrics(driverNode);
+ executorMetrics = executorNode != null? new IOCostUtils.IOMetrics(executorNode) : null;
// initialize here the hashmaps
_stats = new HashMap<>();
- //_transformations = new HashMap<>();
_functions = new HashSet<>();
- localMemory = (long) OptimizerUtils.getLocalMemBudget();
- this._parRDDs = new SparkExecutionContext.MemoryManagerParRDDs(0.1);
- usedMememory = 0;
- cpCost = new double[]{0.0, 0.0};
- spCost = new double[]{0.0, 0.0};
+ localMemoryLimit = (long) OptimizerUtils.getLocalMemBudget();
+ freeLocalMemory = localMemoryLimit;
}
- public static void setCP_FLOPS(long gFlops) {
- CP_FLOPS = gFlops;
- }
- public static void setSP_FLOPS(long gFlops) {
- SP_FLOPS = gFlops;
+ /**
+ * Meant to be used for testing purposes.
+ * @param inputStats variable statistics to preload into the internal statistics map
+ */
+ public void putStats(HashMap<String, VarStats> inputStats) {
+ _stats.putAll(inputStats);
}
- public double getTimeEstimate(Program rtprog) throws CostEstimationException {
+ /**
+ * Intended to be called only when it is certain that the corresponding
+ * variable is not a scalar and its statistics are in {@code _stats} already.
+ * @param statsName the corresponding operand name
+ * @return {@code VarStats} object if the given key is present
+ * in the map storing the current variable statistics.
+ * @throws RuntimeException if the corresponding variable is not in {@code _stats}
+ */
+ public VarStats getStats(String statsName) {
+ VarStats result = _stats.get(statsName);
+ if (result == null) {
+ throw new RuntimeException(statsName+" key not imported yet");
+ }
+ return result;
+ }
+
+ /**
+ * Intended to be called when the corresponding variable could be scalar.
+ * @param statsName the corresponding operand name
+ * @return {@code VarStats} object in any case
+ */
+ public VarStats getStatsWithDefaultScalar(String statsName) {
+ VarStats result = _stats.get(statsName);
+ if (result == null) {
+ result = new VarStats(statsName, null);
+ }
+ return result;
+ }
+
+ public double getTimeEstimate() throws CostEstimationException {
double costs = 0;
//get cost estimate
- for( ProgramBlock pb : rtprog.getProgramBlocks() )
+ for( ProgramBlock pb : _program.getProgramBlocks() )
costs += getTimeEstimatePB(pb);
return costs;
@@ -144,7 +141,7 @@
WhileProgramBlock tmp = (WhileProgramBlock)pb;
for (ProgramBlock pb2 : tmp.getChildBlocks())
ret += getTimeEstimatePB(pb2);
- ret *= DEFAULT_NUMITER;
+ ret *= DEFAULT_NUM_ITER;
}
else if (pb instanceof IfProgramBlock) {
IfProgramBlock tmp = (IfProgramBlock)pb; {
@@ -163,7 +160,7 @@
ret += getTimeEstimatePB(pb2);
// NOTE: currently ParFor blocks are handled as regular for block
// what could lead to very inaccurate estimation in case of complex ParFor blocks
- ret *= OptimizerUtils.getNumIterations(tmp, DEFAULT_NUMITER);
+ ret *= OptimizerUtils.getNumIterations(tmp, DEFAULT_NUM_ITER);
}
else if ( pb instanceof FunctionProgramBlock ) {
FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
@@ -177,846 +174,851 @@
for( Instruction inst : tmp )
{
- ret += getTimeEstimateInst(pb, inst);
- }
- }
- return ret;
- }
-
- private double getTimeEstimateInst(ProgramBlock pb, Instruction inst) throws CostEstimationException {
- double ret;
- if (inst instanceof CPInstruction) {
- maintainCPInstVariableStatistics((CPInstruction)inst);
-
- ret = getTimeEstimateCPInst(pb, (CPInstruction)inst);
-
- if( inst instanceof FunctionCallCPInstruction ) //functions
- {
- FunctionCallCPInstruction finst = (FunctionCallCPInstruction)inst;
- String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName());
- //awareness of recursive functions, missing program
- if( !_functions.contains(fkey) && pb.getProgram()!=null )
+ if( inst instanceof FunctionCallCPInstruction ) //functions
{
- _functions.add(fkey);
- Program prog = pb.getProgram();
- FunctionProgramBlock fpb = prog.getFunctionProgramBlock(
- finst.getNamespace(), finst.getFunctionName());
- ret += getTimeEstimatePB(fpb);
- _functions.remove(fkey);
+ FunctionCallCPInstruction finst = (FunctionCallCPInstruction)inst;
+ String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName());
+ //awareness of recursive functions, missing program
+ if( !_functions.contains(fkey) && pb.getProgram()!=null )
+ {
+ _functions.add(fkey);
+ maintainFCallInputStats(finst);
+ FunctionProgramBlock fpb = _program.getFunctionProgramBlock(fkey, true);
+ ret += getTimeEstimatePB(fpb);
+ maintainFCallOutputStats(finst, fpb);
+ _functions.remove(fkey);
+ }
+ } else {
+ maintainStats(inst);
+ ret += getTimeEstimateInst(inst);
}
}
- } else { // inst instanceof SPInstruction
- ret = 0; //dummy
}
return ret;
}
/**
- * Keep the variable statistics updated and compute I/O cost.
- * NOTE: At program execution reading the files is done once
- * the matrix is needed but cost estimation the place for
- * adding cost is not relevant.
- * @param inst
+ * Creates copies of the {@code VarStats} for the function arguments.
+ * Meant to be called before estimating the execution time of
+ * the function program block of the corresponding function call instruction,
+ * otherwise the relevant statistics would not be available for the estimation.
+ * @param finst the function call instruction
*/
- private void maintainCPInstVariableStatistics(CPInstruction inst) throws CostEstimationException {
- if( inst instanceof VariableCPInstruction )
+ public void maintainFCallInputStats(FunctionCallCPInstruction finst) {
+ CPOperand[] inputs = finst.getInputs();
+ for (int i = 0; i < inputs.length; i++) {
+ DataType dt = inputs[i].getDataType();
+ if (dt == DataType.TENSOR) {
+ throw new DMLRuntimeException("Tensor is not supported for cost estimation");
+ } else if (dt == DataType.MATRIX || dt == DataType.FRAME || dt == DataType.LIST) {
+ String argName = finst.getFunArgNames().get(i);
+ VarStats argStats = getStats(inputs[i].getName());
+ if (inputs[i].getName().equals(argName)) {
+ if (argStats != _stats.get(argName))
+ throw new RuntimeException("Overriding referenced variable within a function call is not a handled case");
+ // reference duplication in different domain
+ argStats.selfRefCount++;
+ } else {
+ // passing the reference to another variable
+ argStats.refCount++;
+ _stats.put(finst.getFunArgNames().get(i), argStats);
+ }
+ }
+ // ignore scalars
+ }
+ }
+
+ /**
+ * Creates copies of the {@code VarStats} for the function output parameters.
+ * Meant to be called after estimating the execution time of
+ * the function program block of the corresponding function call instruction,
+ * otherwise the relevant statistics would not have been created yet.
+ * @param finst the function call instruction
+ * @param fpb the corresponding function program block
+ */
+ public void maintainFCallOutputStats(FunctionCallCPInstruction finst, FunctionProgramBlock fpb) {
+ List<DataIdentifier> params = fpb.getOutputParams();
+ List<String> boundNames = finst.getBoundOutputParamNames();
+ for(int i = 0; i < boundNames.size(); i++) {
+ // iterate over boundNames since the call may bind only a subset (e.g. only the first) of the outputs
+ DataType dt = params.get(i).getDataType();
+ if (dt == DataType.TENSOR) {
+ throw new DMLRuntimeException("Tensor is not supported for cost estimation");
+ }
+ else if (dt == DataType.MATRIX || dt == DataType.FRAME || dt == DataType.LIST) {
+ VarStats boundStats = getStats(params.get(i).getName());
+ boundStats.refCount++;
+ _stats.put(boundNames.get(i), boundStats);
+ }
+ // ignore scalars
+ }
+ }
+
+ /**
+ * Keeps the basic-block variable statistics updated and computes I/O cost.
+ * NOTE: At program execution, files are read only once the matrix is
+ * actually needed, but for cost estimation the exact point at which
+ * the cost is accounted for is not relevant.
+ * @param inst instruction to process
+ */
+ public void maintainStats(Instruction inst) {
+ // CP Instructions changing the map for statistics
+ if(inst instanceof VariableCPInstruction)
{
String opcode = inst.getOpcode();
VariableCPInstruction vinst = (VariableCPInstruction) inst;
+ if (vinst.getInput1().getDataType() == DataType.TENSOR) {
+ throw new DMLRuntimeException("Tensor is not supported for cost estimation");
+ }
String varName = vinst.getInput1().getName();
- if( opcode.equals("createvar") ) {
- DataCharacteristics dataCharacteristics = vinst.getMetaData().getDataCharacteristics();
- VarStats varStats = new VarStats(dataCharacteristics);
- varStats._dirty = true;
- if (vinst.getInput1().getName().startsWith(PREAD_PREFIX)) {
- // NOTE: add I/O here although at execution the reading is done when the input is needed
- String fileName = vinst.getInput2().getName();
+ switch (opcode) {
+ case "createvar":
+ DataCharacteristics dataCharacteristics = vinst.getMetaData().getDataCharacteristics();
+ VarStats varStats = new VarStats(varName, dataCharacteristics);
+ if (vinst.getInput1().getName().startsWith(PREAD_PREFIX)) {
+ // NOTE: add I/O here although at execution the reading is done when the input is needed
+ String fileName = vinst.getInput2().getName();
+ String dataSource = IOCostUtils.getDataSource(fileName);
+ varStats.fileInfo = new Object[]{dataSource, ((MetaDataFormat) vinst.getMetaData()).getFileFormat()};
+ }
+ _stats.put(varName, varStats);
+ break;
+ case "cpvar":
+ VarStats outputStats = getStats(varName);
+ _stats.put(vinst.getInput2().getName(), outputStats);
+ outputStats.refCount++;
+ break;
+ case "mvvar":
+ VarStats statsToMove = _stats.remove(varName);
+ String newName = vinst.getInput2().getName();
+ if (statsToMove != null) statsToMove.varName = newName;
+ _stats.put(newName, statsToMove);
+ break;
+ case "rmvar":
+ for (CPOperand inputOperand: vinst.getInputs()) {
+ VarStats inputVar = _stats.remove(inputOperand.getName());
+ if (inputVar == null) continue; // inputVar == null for scalars
+ // actually remove from memory only if not referenced more than once
+ if (--inputVar.selfRefCount > 0) {
+ _stats.put(inputOperand.getName(), inputVar);
+ } else if (--inputVar.refCount < 1) {
+ removeFromMemory(inputVar);
+ }
+ }
+ break;
+ case "castdts":
+ VarStats scalarStats = new VarStats(vinst.getOutputVariableName(), null);
+ _stats.put(vinst.getOutputVariableName(), scalarStats);
+ break;
+ case "write":
+ String fileName = vinst.getInput2().isLiteral()? vinst.getInput2().getLiteral().getStringValue() : "hdfs_file";
String dataSource = IOCostUtils.getDataSource(fileName);
- varStats._fileInfo = new Object[]{dataSource, ((MetaDataFormat) vinst.getMetaData()).getFileFormat()};
- }
- _stats.put(varName, varStats);
- }
- else if ( opcode.equals("cpvar") ) {
- VarStats copiedStats = _stats.get(varName);
- _stats.put(vinst.getInput2().getName(), copiedStats);
- }
- else if ( opcode.equals("mvvar") ) {
- VarStats statsToMove = _stats.get(varName);
- _stats.remove(vinst.getInput1().getName());
- _stats.put(vinst.getInput2().getName(), statsToMove);
- }
- else if( opcode.equals("rmvar") ) {
- VarStats input =_stats.remove(varName);
- removeFromMemory(input);
+ String formatString = vinst.getInput3().getLiteral().getStringValue();
+ _stats.get(varName).fileInfo = new Object[] {dataSource, FileFormat.safeValueOf(formatString)};
+ break;
}
}
- else if( inst instanceof DataGenCPInstruction ){
+ else if (inst instanceof DataGenCPInstruction){
// variable already created at "createvar"
// now update the sparsity and set size estimate
String opcode = inst.getOpcode();
if (opcode.equals("rand")) {
DataGenCPInstruction dinst = (DataGenCPInstruction) inst;
- VarStats stat = _stats.get(dinst.getOutput().getName());
- stat._mc.setNonZeros((long) (stat.getCells()*dinst.getSparsity()));
- putInMemory(stat);
+ VarStats stat = getStats(dinst.getOutput().getName());
+ stat.characteristics.setNonZeros((long) (stat.getCells()*dinst.getSparsity()));
}
+ } else if (inst instanceof AggregateUnaryCPInstruction) {
+ // specific case to aid future dimensions inferring
+ String opcode = inst.getOpcode();
+ if (!(opcode.equals("nrow") || opcode.equals("ncol") || opcode.equals("length"))) {
+ return;
+ }
+ AggregateUnaryCPInstruction auinst = (AggregateUnaryCPInstruction) inst;
+ VarStats inputStats = getStats(auinst.input1.getName());
+ String outputName = auinst.getOutputVariableName();
+ VarStats outputStats;
+ if (opcode.equals("nrow")) {
+ if (inputStats.getM() < 0) return;
+ outputStats = new VarStats(String.valueOf(inputStats.getM()), null);
+ } else if (opcode.equals("ncol")) {
+ if (inputStats.getN() < 0) return;
+ outputStats = new VarStats(String.valueOf(inputStats.getN()), null);
+ } else { // if (opcode.equals("length"))
+ if (inputStats.getCells() < 0) return;
+ outputStats = new VarStats(String.valueOf(inputStats.getCells()), null);
+ }
+ _stats.put(outputName, outputStats);
}
- else if( inst instanceof FunctionCallCPInstruction )
- {
- FunctionCallCPInstruction finst = (FunctionCallCPInstruction) inst;
- for( String varname : finst.getBoundOutputParamNames() )
- _stats.put(varname, _unknownStats);
+ }
+
+ public double getTimeEstimateInst(Instruction inst) throws CostEstimationException {
+ double timeEstimate;
+ if (inst instanceof CPInstruction) {
+ timeEstimate = getTimeEstimateCPInst((CPInstruction)inst);
+ } else { // inst instanceof SPInstruction
+ timeEstimate = parseSPInst((SPInstruction) inst);
}
+ return timeEstimate;
}
/**
* Estimates the execution time of a single CP instruction
* following the formula <i>C(p) = T_w + max(T_r, T_c)</i> with:
+ * <ul>
* <li>T_w - instruction write (to mem.) time</li>
* <li>T_r - instruction read (to mem.) time</li>
* <li>T_c - instruction compute time</li>
+ * </ul>
*
- * @param pb ?
- * @param inst ?
- * @return
- * @throws CostEstimationException
+ * @param inst instruction for estimation
+ * @return estimated time in seconds
+ * @throws CostEstimationException in case of errors
*/
- private double getTimeEstimateCPInst(ProgramBlock pb, CPInstruction inst) throws CostEstimationException {
- double ret = 0;
+ public double getTimeEstimateCPInst(CPInstruction inst) throws CostEstimationException {
+ double time = 0;
+ VarStats output = null;
if (inst instanceof VariableCPInstruction) {
String opcode = inst.getOpcode();
- VariableCPInstruction varInst = (VariableCPInstruction) inst;
- VarStats input = _stats.get(varInst.getInput1().getName());
+ VariableCPInstruction vinst = (VariableCPInstruction) inst;
+ VarStats input = null;
if (opcode.startsWith("cast")) {
- ret += getLoadTime(input); // disk I/O estimate
- double scanTime = IOCostUtils.getMemReadTime(input); // memory read cost
- double computeTime = getNFLOP_CPVariableInst(varInst, input) / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- CPOperand outputOperand = varInst.getOutput();
- VarStats output = _stats.get(outputOperand.getName());
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(input); // memory write cost
+ input = getStatsWithDefaultScalar(vinst.getInput1().getName());
+ output = getStatsWithDefaultScalar(vinst.getOutput().getName());
+ CPCostUtils.assignOutputMemoryStats(inst, output, input);
}
else if (opcode.equals("write")) {
- ret += getLoadTime(input); // disk I/O estimate
- String fileName = inst.getFilename();
- String dataSource = IOCostUtils.getDataSource(fileName);
- String formatString = varInst.getInput3().getLiteral().getStringValue();
- ret += getNFLOP_CPVariableInst(varInst, input) / CP_FLOPS; // compute time cost
- ret += IOCostUtils.getWriteTime(input.getM(), input.getN(), input.getS(),
- dataSource, Types.FileFormat.safeValueOf(formatString)); // I/O estimate
+ input = getStatsWithDefaultScalar(vinst.getInput1().getName());
+ time += IOCostUtils.getFileSystemWriteTime(input, driverMetrics); // I/O estimate
}
-
- return ret;
- }
- else if (inst instanceof DataGenCPInstruction) {
- DataGenCPInstruction randInst = (DataGenCPInstruction) inst;
- if( randInst.getOpcode().equals("rand") ) {
- long rlen = randInst.getRows();
- long clen = randInst.getCols();
- //int blen = randInst.getBlocksize();
- long nnz = (long) (randInst.getSparsity() * rlen * clen);
- return nnz; //TODO
- }
- else {
- //e.g., seq
- return 1;
- }
+ time += input == null? 0 : loadCPVarStatsAndEstimateTime(input);
+ time += CPCostUtils.getVariableInstTime(vinst, input, output, driverMetrics);
}
else if (inst instanceof UnaryCPInstruction) {
- // --- Operations associated with networking cost only ---
- // TODO: is somehow computational cost relevant for these operations
- if (inst instanceof PrefetchCPInstruction) {
- throw new DMLRuntimeException("TODO");
- } else if (inst instanceof BroadcastCPInstruction) {
- throw new DMLRuntimeException("TODO");
- } else if (inst instanceof EvictCPInstruction) {
- throw new DMLRuntimeException("Costing an instruction for GPU cache eviction is not supported.");
- }
+ UnaryCPInstruction uinst = (UnaryCPInstruction) inst;
+ output = getStatsWithDefaultScalar(uinst.getOutput().getName());
+ if (inst instanceof DataGenCPInstruction || inst instanceof StringInitCPInstruction) {
+ String[] s = InstructionUtils.getInstructionParts(uinst.getInstructionString());
+ VarStats rows = getStatsWithDefaultScalar(s[1]);
+ VarStats cols = getStatsWithDefaultScalar(s[2]);
+ CPCostUtils.assignOutputMemoryStats(inst, output, rows, cols);
+ time += CPCostUtils.getDataGenCPInstTime(uinst, output, driverMetrics);
+ } else {
+ // UnaryCPInstruction input can be any type of object
+ VarStats input = getStatsWithDefaultScalar(uinst.input1.getName());
+				// a few of the unary instructions take a second, optional argument of type matrix
+ VarStats weights = (uinst.input2 == null || uinst.input2.isScalar()) ? null : getStats(uinst.input2.getName());
- // opcodes that does not require estimation
- if (inst.getOpcode().equals("print")) {
- return 0;
- }
- UnaryCPInstruction unaryInst = (UnaryCPInstruction) inst;
- if (unaryInst.input1.isTensor())
- throw new DMLRuntimeException("Tensor is not supported for cost estimation");
- VarStats input = _stats.get(unaryInst.input1.getName());
- VarStats output = _stats.get(unaryInst.getOutput().getName());
+ if (inst instanceof IndexingCPInstruction) {
+ // weights = second input for leftIndex operations
+ IndexingCPInstruction idxInst = (IndexingCPInstruction) inst;
+ VarStats rowLower = getStatsWithDefaultScalar(idxInst.getRowLower().getName());
+ VarStats rowUpper = getStatsWithDefaultScalar(idxInst.getRowUpper().getName());
+ VarStats colLower = getStatsWithDefaultScalar(idxInst.getColLower().getName());
+ VarStats colUpper = getStatsWithDefaultScalar(idxInst.getColUpper().getName());
+ CPCostUtils.assignOutputMemoryStats(inst, output, input, weights, rowLower, rowUpper, colLower, colUpper);
+ } else if (inst instanceof ReorgCPInstruction && inst.getOpcode().equals("rsort")) {
+ ReorgCPInstruction reorgInst = (ReorgCPInstruction) inst;
+ VarStats ixRet = getStatsWithDefaultScalar(reorgInst.getIxRet().getName());
+ CPCostUtils.assignOutputMemoryStats(inst, output, input, ixRet);
+ } else {
+ CPCostUtils.assignOutputMemoryStats(inst, output, input);
+ }
- ret += getLoadTime(input);
- double scanTime = IOCostUtils.getMemReadTime(input);
- double computeTime = getNFLOP_CPUnaryInst(unaryInst, input, output) / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- return ret;
+ time += loadCPVarStatsAndEstimateTime(input);
+ time += weights == null ? 0 : loadCPVarStatsAndEstimateTime(weights);
+ time += CPCostUtils.getUnaryInstTime(uinst, input, weights, output, driverMetrics);
+ }
}
else if (inst instanceof BinaryCPInstruction) {
- BinaryCPInstruction binInst = (BinaryCPInstruction) inst;
- if (binInst.input1.isFrame() || binInst.input2.isFrame())
- throw new DMLRuntimeException("Frame is not supported for cost estimation");
- VarStats input1 = _stats.get(binInst.input1.getName());
- VarStats input2 = _stats.get(binInst.input2.getName());
- VarStats output = _stats.get(binInst.output.getName());
-
- ret += getLoadTime(input1);
- ret += getLoadTime(input2);
- double scanTime = IOCostUtils.getMemReadTime(input1) + IOCostUtils.getMemReadTime(input2);
- double computeTime = getNFLOP_CPBinaryInst(binInst, input1, input2, output) / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- return ret;
- }
- else if (inst instanceof AggregateTernaryCPInstruction) {
- AggregateTernaryCPInstruction aggInst = (AggregateTernaryCPInstruction) inst;
- VarStats input = _stats.get(aggInst.input1.getName());
- VarStats output = _stats.get(aggInst.getOutput().getName());
-
- ret += getLoadTime(input);
- double scanTime = IOCostUtils.getMemReadTime(input);
- double computeTime = (double) (6 * input.getCells()) / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- return ret;
- }
- else if (inst instanceof TernaryFrameScalarCPInstruction) {
- // TODO: put some real implementation:
- // the idea is to take some worse case scenario since different mapping functionalities are possible
- // NOTE: maybe unite with AggregateTernaryCPInstruction since its similar but with factor of 6
- TernaryFrameScalarCPInstruction tInst = (TernaryFrameScalarCPInstruction) inst;
- VarStats input = _stats.get(tInst.input1.getName());
- VarStats output = _stats.get(tInst.getOutput().getName());
-
- ret += getLoadTime(input);
- double scanTime = IOCostUtils.getMemReadTime(input);
- double computeTime = (double) (4*input.getCells()) / CP_FLOPS; // 4 - dummy factor
- ret += Math.max(scanTime, computeTime);
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- return ret;
- }
- else if (inst instanceof QuaternaryCPInstruction) {
- // TODO: put logical compute estimate (maybe putting a complexity factor)
- QuaternaryCPInstruction gInst = (QuaternaryCPInstruction) inst;
- VarStats input1 = _stats.get(gInst.input1.getName());
- VarStats input2 = _stats.get(gInst.input2.getName());
- VarStats input3 = _stats.get(gInst.input3.getName());
- VarStats input4 = _stats.get(gInst.getInput4().getName());
- VarStats output = _stats.get(gInst.getOutput().getName());
-
- ret += getLoadTime(input1) + getLoadTime(input2) + getLoadTime(input3) + getLoadTime(input4);
- double scanTime = IOCostUtils.getMemReadTime(input1)
- + IOCostUtils.getMemReadTime(input2)
- + IOCostUtils.getMemReadTime(input3)
- + IOCostUtils.getMemReadTime(input4);
- double computeTime = (double) (input1.getCells() * input2.getCells() + input3.getCells() + input4.getCells())
- / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- return ret;
- }
- else if (inst instanceof ScalarBuiltinNaryCPInstruction) {
- return 1d / CP_FLOPS;
- }
- else if (inst instanceof MatrixBuiltinNaryCPInstruction) {
- MatrixBuiltinNaryCPInstruction mInst = (MatrixBuiltinNaryCPInstruction) inst;
- VarStats output = _stats.get(mInst.getOutput().getName());
- int numMatrices = 0;
- double scanTime = 0d;
- for (CPOperand operand : mInst.getInputs()) {
- if (operand.isMatrix()) {
- VarStats input = _stats.get(operand.getName());
- ret += getLoadTime(input);
- scanTime += IOCostUtils.getMemReadTime(input);
- numMatrices += 1;
- }
-
+ BinaryCPInstruction binst = (BinaryCPInstruction) inst;
+ VarStats input1 = getStatsWithDefaultScalar(binst.input1.getName());
+ VarStats input2 = getStatsWithDefaultScalar(binst.input2.getName());
+ VarStats weights = binst.input3 == null? null : getStatsWithDefaultScalar(binst.input3.getName());
+ output = getStatsWithDefaultScalar(binst.output.getName());
+ if (inst instanceof AggregateBinaryCPInstruction) {
+ AggregateBinaryCPInstruction aggBinInst = (AggregateBinaryCPInstruction) inst;
+ VarStats transposeLeft = new VarStats(String.valueOf(aggBinInst.transposeLeft), null);
+ VarStats transposeRight = new VarStats(String.valueOf(aggBinInst.transposeRight), null);
+ CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, transposeLeft, transposeRight);
+ } else {
+ CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2);
}
- double computeTime = getNFLOP_CPMatrixBuiltinNaryInst(mInst, numMatrices, output) / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- return ret;
- }
- else if (inst instanceof EvalNaryCPInstruction) {
- throw new RuntimeException("To be implemented later");
- }
- else if (inst instanceof MultiReturnBuiltinCPInstruction) {
- MultiReturnBuiltinCPInstruction mrbInst = (MultiReturnBuiltinCPInstruction) inst;
- VarStats input = _stats.get(mrbInst.input1.getName());
- ret += getLoadTime(input);
- double scanTime = IOCostUtils.getMemReadTime(input);
- double computeTime = getNFLOP_CPMultiReturnBuiltinInst(mrbInst, input) / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- for (CPOperand operand : mrbInst.getOutputs()) {
- VarStats output = _stats.get(operand.getName());
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- }
- return ret;
- }
- else if (inst instanceof CtableCPInstruction) {
- CtableCPInstruction ctInst = (CtableCPInstruction) inst;
- VarStats input1 = _stats.get(ctInst.input1.getName());
- VarStats input2 = _stats.get(ctInst.input2.getName());
- VarStats input3 = _stats.get(ctInst.input3.getName());
- //VarStats output = _stats.get(ctInst.getOutput().getName());
-
- ret += getLoadTime(input1) + getLoadTime(input2) + getLoadTime(input3);
- double scanTime = IOCostUtils.getMemReadTime(input1)
- + IOCostUtils.getMemReadTime(input2)
- + IOCostUtils.getMemReadTime(input3);
- double computeTime = (double) input1.getCellsWithSparsity() / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- // TODO: figure out what dimensions to assign to the output matrix stats 'output'
- throw new DMLRuntimeException("Operation "+inst.getOpcode()+" is not supported yet due to a unpredictable output");
- }
- else if (inst instanceof PMMJCPInstruction) {
- PMMJCPInstruction pmmInst = (PMMJCPInstruction) inst;
- VarStats input1 = _stats.get(pmmInst.input1.getName());
- VarStats input2 = _stats.get(pmmInst.input2.getName());
- VarStats output = _stats.get(pmmInst.getOutput().getName());
-
- ret += getLoadTime(input1) + getLoadTime(input2);
- double scanTime = IOCostUtils.getMemReadTime(input1) + IOCostUtils.getMemReadTime(input2);
- double computeTime = input1.getCells() * input2.getCellsWithSparsity() / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
- putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
- return ret;
+ time += loadCPVarStatsAndEstimateTime(input1);
+ time += loadCPVarStatsAndEstimateTime(input2);
+ time += weights == null? 0 : loadCPVarStatsAndEstimateTime(weights);
+ time += CPCostUtils.getBinaryInstTime(binst, input1, input2, weights, output, driverMetrics);
}
else if (inst instanceof ParameterizedBuiltinCPInstruction) {
- ParameterizedBuiltinCPInstruction paramInst = (ParameterizedBuiltinCPInstruction) inst;
- String[] parts = InstructionUtils.getInstructionParts(inst.toString());
- VarStats input = _stats.get( parts[1].substring(7).replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER, "") );
- VarStats output = _stats.get( parts[parts.length-1] );
+ if (inst instanceof ParamservBuiltinCPInstruction) {
+ throw new RuntimeException("ParamservBuiltinCPInstruction is not supported for estimation");
+ }
+ ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
- ret += getLoadTime(input);
- double scanTime = IOCostUtils.getMemReadTime(input);
- double computeTime = getNFLOP_CPParameterizedBuiltinInst(paramInst, input, output) / CP_FLOPS;
- ret += Math.max(scanTime, computeTime);
+ VarStats input1 = getParameterizedBuiltinParamStats("target", pinst.getParameterMap(), true); // required
+ VarStats input2 = null; // optional
+ switch (inst.getOpcode()) {
+ case "rmempty":
+ input2 = getParameterizedBuiltinParamStats("select", pinst.getParameterMap(), false);
+ break;
+ case "contains":
+ input2 = getParameterizedBuiltinParamStats("pattern", pinst.getParameterMap(), false);
+ break;
+ case "groupedagg":
+ input2 = getParameterizedBuiltinParamStats("groups", pinst.getParameterMap(), false);
+ break;
+ }
+ output = getStatsWithDefaultScalar(pinst.getOutputVariableName());
+ CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2);
+
+ time += input1 != null? loadCPVarStatsAndEstimateTime(input1) : 0;
+ time += input2 != null? loadCPVarStatsAndEstimateTime(input2) : 0;
+ time += CPCostUtils.getParameterizedBuiltinInstTime(pinst, input1, output, driverMetrics);
+ } else if (inst instanceof MultiReturnBuiltinCPInstruction) {
+ MultiReturnBuiltinCPInstruction mrbinst = (MultiReturnBuiltinCPInstruction) inst;
+ VarStats input = getStats(mrbinst.input1.getName());
+ VarStats[] outputs = new VarStats[mrbinst.getOutputs().size()];
+ int i = 0;
+ for (CPOperand operand : mrbinst.getOutputs()) {
+ if (!operand.isMatrix()) {
+ throw new DMLRuntimeException("MultiReturnBuiltinCPInstruction expects only matrix output objects");
+ }
+ VarStats current = getStats(operand.getName());
+ outputs[i] = current;
+ i++;
+ }
+ // input and outputs switched on purpose: exclusive behaviour for this instruction
+ CPCostUtils.assignOutputMemoryStats(inst, input, outputs);
+ for (VarStats current : outputs) putInMemory(current);
+
+ time += loadCPVarStatsAndEstimateTime(input);
+ time += CPCostUtils.getMultiReturnBuiltinInstTime(mrbinst, input, outputs, driverMetrics);
+ // the only place to return directly here (output put in memory already)
+ return time;
+ }
+ else if (inst instanceof ComputationCPInstruction) {
+ if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction || inst instanceof CompressionCPInstruction || inst instanceof DeCompressionCPInstruction) {
+ throw new RuntimeException(inst.getClass().getName() + " is not supported for estimation");
+ }
+ ComputationCPInstruction cinst = (ComputationCPInstruction) inst;
+ VarStats input1 = getStatsWithDefaultScalar(cinst.input1.getName()); // 1 input: AggregateTernaryCPInstruction
+			// in general, only the first input operand is guaranteed to be initialized
+			// assume the operands can also be scalars (often they are literals or scalar arguments not relevant to the cost estimation)
+ VarStats input2 = cinst.input2 == null? null : getStatsWithDefaultScalar(cinst.input2.getName()); // 2 inputs: PMMJCPInstruction
+ VarStats input3 = cinst.input3 == null? null : getStatsWithDefaultScalar(cinst.input3.getName()); // 3 inputs: TernaryCPInstruction, CtableCPInstruction
+ VarStats input4 = cinst.input4 == null? null : getStatsWithDefaultScalar(cinst.input4.getName()); // 4 inputs (possibly): QuaternaryCPInstruction
+ output = getStatsWithDefaultScalar(cinst.getOutput().getName());
+ if (inst instanceof CtableCPInstruction) {
+ CtableCPInstruction tableInst = (CtableCPInstruction) inst;
+ VarStats outDim1 = getCTableDim(tableInst.getOutDim1());
+ VarStats outDim2 = getCTableDim(tableInst.getOutDim2());
+ CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, outDim1, outDim2);
+ } else {
+ CPCostUtils.assignOutputMemoryStats(inst, output, input1, input2, input3, input4);
+ }
+
+ time += loadCPVarStatsAndEstimateTime(input1);
+ time += input2 == null? 0 : loadCPVarStatsAndEstimateTime(input2);
+ time += input3 == null? 0 : loadCPVarStatsAndEstimateTime(input3);
+ time += input4 == null? 0 : loadCPVarStatsAndEstimateTime(input4);
+ time += CPCostUtils.getComputationInstTime(cinst, input1, input2, input3, input4, output, driverMetrics);
+ }
+ else if (inst instanceof BuiltinNaryCPInstruction) {
+ BuiltinNaryCPInstruction bninst = (BuiltinNaryCPInstruction) inst;
+ output = getStatsWithDefaultScalar(bninst.getOutput().getName());
+ // putInMemory(output);
+ if (bninst instanceof ScalarBuiltinNaryCPInstruction) {
+ return CPCostUtils.getBuiltinNaryInstTime(bninst, null, output, driverMetrics);
+ }
+ VarStats[] inputs = new VarStats[bninst.getInputs().length];
+ int i = 0;
+ for (CPOperand operand : bninst.getInputs()) {
+ if (operand.isMatrix()) {
+ VarStats input = getStatsWithDefaultScalar(operand.getName());
+ time += loadCPVarStatsAndEstimateTime(input);
+ inputs[i] = input;
+ i++;
+ }
+ }
+			// trim the array to its actual number of entries
+			inputs = Arrays.copyOf(inputs, i);
+ CPCostUtils.assignOutputMemoryStats(inst, output, inputs);
+ time += CPCostUtils.getBuiltinNaryInstTime(bninst, inputs, output, driverMetrics);
+ }
+ else { // SqlCPInstruction
+ throw new RuntimeException(inst.getClass().getName() + " is not supported for estimation");
+ }
+
+ if (output != null)
putInMemory(output);
- ret += IOCostUtils.getMemWriteTime(output);
+ return time;
+ }
+
+ /**
+	 * Parses a Spark instruction and stores the corresponding cost
+	 * for computing the output variable in the RDD statistics
+	 * object related to that variable.
+ * This method is responsible for initializing the corresponding
+ * {@code RDDStats} object for each output variable, including for
+ * outputs that are explicitly brought back to CP (Spark action within the instruction).
+ * It returns the time estimate only for those instructions that bring the
+ * output explicitly to CP. For the rest, the estimated time (cost) is
+ * stored as part of the corresponding RDD statistics, emulating the
+ * lazy evaluation execution of Spark.
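+	 * <p>For example, a reblock instruction only stores its estimated cost in the
+	 * output {@code RDDStats}, while an instruction whose output is marked as
+	 * collected returns the accumulated cost of its RDD lineage as the time estimate.</p>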
+ *
+ * @param inst Spark instruction for parsing
+ * @return if explicit action, estimated time in seconds, else always 0
+	 * @throws CostEstimationException in case of insufficient local (CP) memory
+ */
+ public double parseSPInst(SPInstruction inst) throws CostEstimationException {
+		/* Logic for the parallelization factors:
+		 * the given executor metrics relate to the peak performance per node,
+		 * i.e. utilizing all available resources. Spark operations, however,
+		 * are executed by several tasks per node, so the execution/read time
+		 * per operation is the potential execution time achievable with the
+		 * full node resources, divided by the number of nodes running tasks,
+		 * and then adjusted to the actual number of tasks per node to account
+		 * for the fact that if not all cores of a node are reading,
+		 * not all of its resources are utilized.
+		 */
+ VarStats output;
+ if (inst instanceof ReblockSPInstruction || inst instanceof CSVReblockSPInstruction || inst instanceof LIBSVMReblockSPInstruction) {
+ UnarySPInstruction uinst = (UnarySPInstruction) inst;
+ VarStats input = getStats((uinst).input1.getName());
+ output = getStats((uinst).getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input);
+
+ output.fileInfo = input.fileInfo;
+ // the resulting binary rdd is being hash-partitioned after the reblock
+ output.rddStats.hashPartitioned = true;
+ output.rddStats.cost = SparkCostUtils.getReblockInstTime(inst.getOpcode(), input, output, executorMetrics);
+ } else if (inst instanceof CheckpointSPInstruction) {
+ CheckpointSPInstruction cinst = (CheckpointSPInstruction) inst;
+ VarStats input = getStats(cinst.input1.getName());
+ double loadTime = loadRDDStatsAndEstimateTime(input);
+
+ output = getStats(cinst.getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input);
+ output.rddStats.checkpoint = true;
+ // assume the rdd object is only marked as checkpoint;
+ // adding spilling or serializing cost is skipped
+ output.rddStats.cost = loadTime;
+ } else if (inst instanceof RandSPInstruction) {
+ // Rand instruction takes no RDD input;
+ RandSPInstruction rinst = (RandSPInstruction) inst;
+ String opcode = rinst.getOpcode();
+ int randType = -1; // default for non-random object generation operations
+ if (opcode.equals(RAND_OPCODE) || opcode.equals(FRAME_OPCODE)) {
+ if (rinst.getMinValue() == 0d && rinst.getMaxValue() == 0d) { // empty matrix
+ randType = 0;
+ } else if (rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue()) { // allocate, array fill
+ randType = 1;
+ } else { // full rand
+ randType = 2;
+ }
+ }
+ output = getStats(rinst.output.getName());
+ SparkCostUtils.assignOutputRDDStats(inst, output);
+
+ output.rddStats.cost = getRandInstTime(opcode, randType, output, executorMetrics);
+ } else if (inst instanceof AggregateUnarySPInstruction || inst instanceof AggregateUnarySketchSPInstruction) {
+ UnarySPInstruction auinst = (UnarySPInstruction) inst;
+ VarStats input = getStats((auinst).input1.getName());
+ double loadTime = loadRDDStatsAndEstimateTime(input);
+
+ output = getStats((auinst).getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input);
+
+ output.rddStats.cost = loadTime + SparkCostUtils.getAggUnaryInstTime(auinst, input, output, executorMetrics);
+ } else if (inst instanceof IndexingSPInstruction) {
+ IndexingSPInstruction ixdinst = (IndexingSPInstruction) inst;
+ boolean isLeftCacheType = (inst instanceof MatrixIndexingSPInstruction &&
+ ((MatrixIndexingSPInstruction) ixdinst).getLixType() == LeftIndex.LixCacheType.LEFT);
+ VarStats input1; // always assigned
+ VarStats input2 = null; // assigned only if case of indexing
+ double loadTime = 0;
+ if (ixdinst.getOpcode().toLowerCase().contains("left")) {
+ if (isLeftCacheType) {
+ input1 = getStats(ixdinst.input2.getName());
+ input2 = getStats(ixdinst.input1.getName());
+ } else {
+ input1 = getStats(ixdinst.input1.getName());
+ input2 = getStats(ixdinst.input2.getName());
+ }
+
+ if (ixdinst.getOpcode().equals(LeftIndex.OPCODE)) {
+ loadTime += loadRDDStatsAndEstimateTime(input2);
+ } else { // mapLeftIndex
+ loadTime += loadCPVarStatsAndEstimateTime(input2);
+ }
+ } else {
+ input1 = getStats(ixdinst.input1.getName());
+ }
+ loadTime += loadRDDStatsAndEstimateTime(input1);
+
+ VarStats rowLower = getStatsWithDefaultScalar(ixdinst.getRowLower().getName());
+ VarStats rowUpper = getStatsWithDefaultScalar(ixdinst.getRowUpper().getName());
+ VarStats colLower = getStatsWithDefaultScalar(ixdinst.getColLower().getName());
+ VarStats colUpper = getStatsWithDefaultScalar(ixdinst.getColUpper().getName());
+ output = getStats(ixdinst.getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, rowLower, rowUpper, colLower, colUpper);
+
+ output.rddStats.cost = loadTime +
+ SparkCostUtils.getIndexingInstTime(ixdinst, input1, input2, output, driverMetrics, executorMetrics);
+		} else if (inst instanceof UnarySPInstruction) { // general unary handling; keep always after the more specific unary cases above
+ UnarySPInstruction uinst = (UnarySPInstruction) inst;
+ VarStats input = getStats((uinst).input1.getName());
+ double loadTime = loadRDDStatsAndEstimateTime(input);
+ output = getStats((uinst).getOutputVariableName());
+
+ if (uinst instanceof UnaryMatrixSPInstruction || inst instanceof UnaryFrameSPInstruction) {
+ SparkCostUtils.assignOutputRDDStats(inst, output, input);
+ output.rddStats.cost = loadTime + SparkCostUtils.getUnaryInstTime(uinst.getOpcode(), input, output, executorMetrics);
+ } else if (uinst instanceof ReorgSPInstruction || inst instanceof MatrixReshapeSPInstruction) {
+ if (uinst instanceof ReorgSPInstruction && uinst.getOpcode().equals("rsort")) {
+ ReorgSPInstruction reorgInst = (ReorgSPInstruction) inst;
+ VarStats ixRet = getStatsWithDefaultScalar(reorgInst.getIxRet().getName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input, ixRet);
+ } else {
+ SparkCostUtils.assignOutputRDDStats(inst, output, input);
+ }
+ output.rddStats.cost = loadTime + SparkCostUtils.getReorgInstTime(uinst, input, output, executorMetrics);
+ } else if (uinst instanceof TsmmSPInstruction || inst instanceof Tsmm2SPInstruction) {
+ SparkCostUtils.assignOutputRDDStats(inst, output, input);
+ output.rddStats.cost = loadTime + SparkCostUtils.getTSMMInstTime(uinst, input, output, driverMetrics, executorMetrics);
+ } else if (uinst instanceof CentralMomentSPInstruction) {
+ VarStats weights = null;
+ if (uinst.input3 != null) {
+ weights = getStats(uinst.input2.getName());
+ loadTime += loadRDDStatsAndEstimateTime(weights);
+ }
+ SparkCostUtils.assignOutputRDDStats(inst, output, input, weights);
+ output.rddStats.cost = loadTime +
+ SparkCostUtils.getCentralMomentInstTime((CentralMomentSPInstruction) uinst, input, weights, output, executorMetrics);
+ } else if (inst instanceof CastSPInstruction) {
+ SparkCostUtils.assignOutputRDDStats(inst, output, input);
+ output.rddStats.cost = loadTime + SparkCostUtils.getCastInstTime((CastSPInstruction) inst, input, output, executorMetrics);
+ } else if (inst instanceof QuantileSortSPInstruction) {
+ VarStats weights = null;
+ if (uinst.input2 != null) {
+ weights = getStats(uinst.input2.getName());
+ loadTime += loadRDDStatsAndEstimateTime(weights);
+ }
+ SparkCostUtils.assignOutputRDDStats(inst, output, input, weights);
+ output.rddStats.cost = loadTime +
+ SparkCostUtils.getQSortInstTime((QuantileSortSPInstruction) uinst, input, weights, output, executorMetrics);
+ } else {
+ throw new RuntimeException("Unsupported Unary Spark instruction of type " + inst.getClass().getName());
+ }
+ } else if (inst instanceof BinaryFrameFrameSPInstruction || inst instanceof BinaryFrameMatrixSPInstruction || inst instanceof BinaryMatrixMatrixSPInstruction || inst instanceof BinaryMatrixScalarSPInstruction) {
+ BinarySPInstruction binst = (BinarySPInstruction) inst;
+ VarStats input1 = getStatsWithDefaultScalar((binst).input1.getName());
+ VarStats input2 = getStatsWithDefaultScalar((binst).input2.getName());
+ // handle input rdd loading
+ double loadTime = loadRDDStatsAndEstimateTime(input1);
+ if (inst instanceof BinaryMatrixBVectorSPInstruction) {
+ loadTime += loadCPVarStatsAndEstimateTime(input2);
+ } else {
+ loadTime += loadRDDStatsAndEstimateTime(input2);
+ }
+
+ output = getStats((binst).getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2);
+
+ output.rddStats.cost = loadTime +
+ SparkCostUtils.getBinaryInstTime(inst, input1, input2, output, driverMetrics, executorMetrics);
+ } else if (inst instanceof AppendSPInstruction) {
+ AppendSPInstruction ainst = (AppendSPInstruction) inst;
+ VarStats input1 = getStats(ainst.input1.getName());
+ double loadTime = loadRDDStatsAndEstimateTime(input1);
+ VarStats input2 = getStats(ainst.input2.getName());
+ if (ainst instanceof AppendMSPInstruction) {
+ loadTime += loadCPVarStatsAndEstimateTime(input2);
+ } else {
+ loadTime += loadRDDStatsAndEstimateTime(input2);
+ }
+ output = getStats(ainst.getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2);
+
+ output.rddStats.cost = loadTime + SparkCostUtils.getAppendInstTime(ainst, input1, input2, output, driverMetrics, executorMetrics);
+ } else if (inst instanceof AggregateBinarySPInstruction || inst instanceof PmmSPInstruction || inst instanceof PMapmmSPInstruction || inst instanceof ZipmmSPInstruction) {
+ BinarySPInstruction binst = (BinarySPInstruction) inst;
+ VarStats input1, input2;
+ double loadTime = 0;
+ if (binst instanceof MapmmSPInstruction || binst instanceof PmmSPInstruction) {
+ MapMult.CacheType cacheType = binst instanceof MapmmSPInstruction?
+ ((MapmmSPInstruction) binst).getCacheType() :
+ ((PmmSPInstruction) binst).getCacheType();
+ if (cacheType.isRight()) {
+ input1 = getStats(binst.input1.getName());
+ input2 = getStats(binst.input2.getName());
+ } else {
+ input1 = getStats(binst.input2.getName());
+ input2 = getStats(binst.input1.getName());
+ }
+ loadTime += loadRDDStatsAndEstimateTime(input1);
+ loadTime += loadCPVarStatsAndEstimateTime(input2);
+ } else {
+ input1 = getStats(binst.input1.getName());
+ input2 = getStats(binst.input2.getName());
+ loadTime += loadRDDStatsAndEstimateTime(input1);
+ loadTime += loadRDDStatsAndEstimateTime(input2);
+ }
+ output = getStats(binst.getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2);
+
+ output.rddStats.cost = loadTime +
+ SparkCostUtils.getMatMulInstTime(binst, input1, input2, output, driverMetrics, executorMetrics);
+ } else if (inst instanceof MapmmChainSPInstruction) {
+ MapmmChainSPInstruction mmchaininst = (MapmmChainSPInstruction) inst;
+ VarStats input1 = getStats(mmchaininst.input1.getName());
+			VarStats input2 = getStats(mmchaininst.input2.getName());
+ VarStats input3 = null;
+ double loadTime = loadRDDStatsAndEstimateTime(input1) + loadCPVarStatsAndEstimateTime(input2);
+ if (mmchaininst.input3 != null) {
+ input3 = getStats(mmchaininst.input3.getName());
+ loadTime += loadCPVarStatsAndEstimateTime(input3);
+ }
+ output = getStats(mmchaininst.output.getName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, input3);
+
+ output.rddStats.cost = loadTime +
+ SparkCostUtils.getMatMulChainInstTime(mmchaininst, input1, input2, input3, output, driverMetrics, executorMetrics);
+ } else if (inst instanceof CtableSPInstruction) {
+ CtableSPInstruction tableInst = (CtableSPInstruction) inst;
+ VarStats input1 = getStats(tableInst.input1.getName());
+ VarStats input2 = getStatsWithDefaultScalar(tableInst.input2.getName());
+ VarStats input3 = getStatsWithDefaultScalar(tableInst.input3.getName());
+ double loadTime = loadRDDStatsAndEstimateTime(input1) +
+ loadRDDStatsAndEstimateTime(input2) + loadRDDStatsAndEstimateTime(input3);
+
+ output = getStats(tableInst.getOutputVariableName());
+ VarStats outDim1 = getCTableDim(tableInst.getOutDim1());
+ VarStats outDim2 = getCTableDim(tableInst.getOutDim2());
+			// third input not relevant for the assignment (dimension inference)
+ SparkCostUtils.assignOutputRDDStats(inst, output, input1, input2, outDim1, outDim2);
+
+ output.rddStats.cost = loadTime +
+ SparkCostUtils.getCtableInstTime(tableInst, input1, input2, input3, output, executorMetrics);
+ } else if (inst instanceof ParameterizedBuiltinSPInstruction) {
+ ParameterizedBuiltinSPInstruction paramInst = (ParameterizedBuiltinSPInstruction) inst;
+
+ VarStats input1 = getParameterizedBuiltinParamStats("target", paramInst.getParameterMap(), true); // required
+ double loadTime = input1 != null? loadRDDStatsAndEstimateTime(input1) : 0;
+ VarStats input2 = null; // optional
+ switch (inst.getOpcode()) {
+ case "rmempty":
+ input2 = getParameterizedBuiltinParamStats("offset", paramInst.getParameterMap(), false);
+ if (Boolean.parseBoolean(paramInst.getParameterMap().get("bRmEmptyBC"))) {
+ loadTime += input2 != null? loadCPVarStatsAndEstimateTime(input2) : 0; // broadcast
+ } else {
+ loadTime += input2 != null? loadRDDStatsAndEstimateTime(input2) : 0;
+ }
+ break;
+ case "contains":
+ input2 = getParameterizedBuiltinParamStats("pattern", paramInst.getParameterMap(), false);
+ break;
+ case "groupedagg":
+ input2 = getParameterizedBuiltinParamStats("groups", paramInst.getParameterMap(), false);
+					// in some cases a third parameter is also needed here
+ break;
+ }
+
+ output = getStatsWithDefaultScalar(paramInst.getOutputVariableName());
+ SparkCostUtils.assignOutputRDDStats(inst, output, input1);
+
+ output.rddStats.cost = loadTime + SparkCostUtils.getParameterizedBuiltinInstTime(paramInst, input1, input2, output,
+ driverMetrics, executorMetrics);
+ } else if (inst instanceof WriteSPInstruction) {
+ WriteSPInstruction wInst = (WriteSPInstruction) inst;
+ VarStats input = getStats(wInst.input1.getName());
+ double loadTime = loadRDDStatsAndEstimateTime(input);
+ // extract and assign all needed parameters for writing a file
+ String fileName = wInst.getInput2().isLiteral()? wInst.getInput2().getLiteral().getStringValue() : "hdfs_file";
+			String dataSource = IOCostUtils.getDataSource(fileName); // e.g. "hdfs_file" -> "hdfs"
+ String formatString = wInst.getInput3().isLiteral()? wInst.getInput3().getLiteral().getStringValue() : "text";
+ input.fileInfo = new Object[] {dataSource, FileFormat.safeValueOf(formatString)};
+ // return time estimate here since no corresponding RDD statistics exist
+ return loadTime + IOCostUtils.getHadoopWriteTime(input, executorMetrics); // I/O estimate
+ }
+// else if (inst instanceof CumulativeOffsetSPInstruction) {
+//
+// } else if (inst instanceof CovarianceSPInstruction) {
+//
+// } else if (inst instanceof QuantilePickSPInstruction) {
+//
+// } else if (inst instanceof TernarySPInstruction) {
+//
+// } else if (inst instanceof AggregateTernarySPInstruction) {
+//
+// } else if (inst instanceof QuaternarySPInstruction) {
+//
+// }
+ else {
+ throw new RuntimeException("Unsupported instruction: " + inst.getOpcode());
+ }
+		// output.rddStats should always be initialized at this point
+ if (output.rddStats.isCollected) {
+ if (!output.isScalar()) {
+ output.allocatedMemory = OptimizerUtils.estimateSizeExactSparsity(output.characteristics);
+ putInMemory(output);
+ }
+ double ret = output.rddStats.cost;
+ output.rddStats = null;
return ret;
}
- else if( inst instanceof FunctionCallCPInstruction )
- {
- FunctionCallCPInstruction finst = (FunctionCallCPInstruction)inst;
- //TODO recursive function calls and
- Program prog = pb.getProgram();
- FunctionProgramBlock fpb = prog.getFunctionProgramBlock(
- finst.getNamespace(), finst.getFunctionName());
- return getTimeEstimatePB(fpb);
- }
- else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
- throw new DMLRuntimeException("MultiReturnParametrized built-in instructions are not supported.");
- }
- else if (inst instanceof CompressionCPInstruction || inst instanceof DeCompressionCPInstruction) {
- throw new DMLRuntimeException("(De)Compression instructions are not supported yet.");
- }
- else if (inst instanceof SqlCPInstruction) {
- throw new DMLRuntimeException("SQL instructions are not supported.");
- }
- System.out.println("Unsupported instruction: " + inst.getOpcode());
- return 1;
- }
- private double getNFLOP_CPVariableInst(VariableCPInstruction inst, VarStats input) throws CostEstimationException {
- switch (inst.getOpcode()) {
- case "write":
- String fmtStr = inst.getInput3().getLiteral().getStringValue();
- Types.FileFormat fmt = Types.FileFormat.safeValueOf(fmtStr);
- double xwrite = fmt.isTextFormat() ? DEFAULT_NFLOP_TEXT_IO : DEFAULT_NFLOP_CP;
- return input.getCellsWithSparsity() * xwrite;
- case "cast_as_matrix":
- case "cast_as_frame":
- return input.getCells();
- default:
- return DEFAULT_NFLOP_CP;
- }
+ return 0;
}
- private double getNFLOP_CPUnaryInst(UnaryCPInstruction inst, VarStats input, VarStats output) throws CostEstimationException {
- String opcode = inst.getOpcode();
- // --- Operations for data generation ---
- if( inst instanceof DataGenCPInstruction ) {
- if (opcode.equals(DataGen.RAND_OPCODE)) {
- DataGenCPInstruction rinst = (DataGenCPInstruction) inst;
- if( rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0 )
- return DEFAULT_NFLOP_CP; // empty matrix
- else if( rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue() )
- return 8.0 * output.getCells();
- else { // full rand
- if (rinst.getSparsity() == 1.0)
- return 32.0 * output.getCells() + 8.0 * output.getCells();//DENSE gen (incl allocate)
- if (rinst.getSparsity() < MatrixBlock.SPARSITY_TURN_POINT)
- return 3.0 * output.getCellsWithSparsity() + 24.0 * output.getCellsWithSparsity(); //SPARSE gen (incl allocate)
- return 2.0 * output.getCells() + 8.0 * output.getCells(); //DENSE gen (incl allocate)
- }
- } else if (opcode.equals(DataGen.SEQ_OPCODE)) {
- return DEFAULT_NFLOP_CP * output.getCells();
- } else {
- throw new RuntimeException("To be implemented later");
- }
+ public double getTimeEstimateSparkJob(VarStats varToCollect) {
+ if (varToCollect.rddStats == null) {
+ throw new RuntimeException("Missing RDD statistics for estimating execution time for Spark Job");
}
- else if( inst instanceof StringInitCPInstruction ) {
- return DEFAULT_NFLOP_CP * output.getCells();
- }
- // --- General unary ---
-// if (input == null)
-// input = _scalarStats; // TODO: consider if needed: if yes -> stats garbage collections?
-
- if (inst instanceof MMTSJCPInstruction) {
- MMTSJ.MMTSJType type = ((MMTSJCPInstruction) inst).getMMTSJType();
- if (type.isLeft()) {
- if (input.isSparse()) {
- return input.getM() * input.getN() * input.getS() * input.getN() * input.getS() / 2;
- } else {
- return input.getM() * input.getN() * input.getS() * input.getN() / 2;
- }
- } else {
- throw new RuntimeException("To be implemented later");
- }
- } else if (inst instanceof AggregateUnaryCPInstruction) {
- AggregateUnaryCPInstruction uainst = (AggregateUnaryCPInstruction) inst;
- AggregateUnaryCPInstruction.AUType autype = uainst.getAUType();
- if (autype != AggregateUnaryCPInstruction.AUType.DEFAULT) {
- switch (autype) {
- case NROW:
- case NCOL:
- case LENGTH:
- return DEFAULT_NFLOP_NOOP;
- case COUNT_DISTINCT:
- case COUNT_DISTINCT_APPROX:
- // TODO: get real cost
- return input.getCells();
- case UNIQUE:
- // TODO: get real cost
- return input.getCells();
- case LINEAGE:
- // TODO: get real cost
- return DEFAULT_NFLOP_NOOP;
- case EXISTS:
- // TODO: get real cost
- return 1;
- default:
- // NOTE: not reachable - only for consistency
- return 0;
- }
- } else {
- int k = getComputationFactorUAOp(opcode);
- if (opcode.equals("cm")) {
- // TODO: extract attribute first and implement the logic then (CentralMomentCPInstruction)
- throw new RuntimeException("Not implemented yet.");
- } else if (opcode.equals("ua+") || opcode.equals("uar+") || opcode.equals("uac+")) {
- return k*input.getCellsWithSparsity();
- } else { // NOTE: assumes all other cases were already handled properly
- return (input!=null)?k*input.getCells() : 1;
- }
- }
- } else if(inst instanceof UnaryScalarCPInstruction) {
- // TODO: consider if 1 is always reasonable
- return 1;
- } else if(inst instanceof UnaryFrameCPInstruction) {
- switch (opcode) {
- case "typeOf":
- return 1;
- case "detectSchema":
- // TODO: think of a real static cost
- return 1;
- case "colnames":
- // TODO: is the number of the column reasonable result?
- return input.getN();
- }
- }else if (inst instanceof UnaryMatrixCPInstruction){
- if (opcode.equals("print"))
- return 1;
- else if (opcode.equals("inverse")) {
- // TODO: implement
- return 0;
- } else if (opcode.equals("cholesky")) {
- // TODO: implement
- return 0;
- }
- // NOTE: What is xbu?
- double xbu = 1; //default for all ops
- if( opcode.equals("plogp") ) xbu = 2;
- else if( opcode.equals("round") ) xbu = 4;
- switch (opcode) { //opcodes: exp, abs, sin, cos, tan, sign, sqrt, plogp, print, round, sprop, sigmoid
- case "sin": case "tan": case "round": case "abs":
- case "sqrt": case "sprop": case "sigmoid": case "sign":
- return xbu * input.getCellsWithSparsity();
- default:
- // TODO: does that apply to all valid unary matrix operators
- return xbu * input.getCells();
- }
- } else if (inst instanceof ReorgCPInstruction || inst instanceof ReshapeCPInstruction) {
- return input.getCellsWithSparsity();
- } else if (inst instanceof IndexingCPInstruction) {
- // NOTE: I doubt that this is formula for the cost is correct
- if (opcode.equals(RightIndex.OPCODE)) {
- // TODO: check correctness since I changed the initial formula to not use input 2
- return DEFAULT_NFLOP_CP * input.getCellsWithSparsity();
- } else if (opcode.equals(LeftIndex.OPCODE)) {
- VarStats indexMatrixStats = _stats.get(inst.input2.getName());
- return DEFAULT_NFLOP_CP * input.getCellsWithSparsity()
- + 2 * DEFAULT_NFLOP_CP * indexMatrixStats.getCellsWithSparsity();
- }
- } else if (inst instanceof MMChainCPInstruction) {
- // NOTE: reduction by factor 2 because matrix mult better than average flop count
- // (mmchain essentially two matrix-vector muliplications)
- return (2+2) * input.getCellsWithSparsity() / 2;
- } else if (inst instanceof UaggOuterChainCPInstruction) {
- // TODO: implement - previous implementation is missing
- throw new RuntimeException("Not implemented yet.");
- } else if (inst instanceof QuantileSortCPInstruction) {
- // NOTE: mergesort since comparator used
- long m = input.getM();
- double sortCosts = 0;
- if(inst.input2 == null)
- sortCosts = DEFAULT_NFLOP_CP * m + m;
- else //w/ weights
- sortCosts = DEFAULT_NFLOP_CP * (input.isSparse() ? m * input.getS() : m);
-
- return sortCosts + m*(int)(Math.log(m)/Math.log(2)) + // mergesort
- DEFAULT_NFLOP_CP * m;
- } else if (inst instanceof DnnCPInstruction) {
- // TODO: implement the cost function for this
- throw new RuntimeException("Not implemented yet.");
- }
- // NOTE: the upper cases should consider all possible scenarios for unary instructions
- throw new DMLRuntimeException("Attempt for costing unsupported unary instruction.");
- }
-
- private double getNFLOP_CPBinaryInst(BinaryCPInstruction inst, VarStats input1, VarStats input2, VarStats output) throws CostEstimationException {
- if (inst instanceof AppendCPInstruction) {
- return DEFAULT_NFLOP_CP*input1.getCellsWithSparsity()*input2.getCellsWithSparsity();
- } else if (inst instanceof AggregateBinaryCPInstruction) { // ba+*
- // TODO: formula correct?
- // NOTE: reduction by factor 2 because matrix mult better than average flop count (2*x/2 = x)
- if (!input1.isSparse() && !input2.isSparse())
- return input1.getCells() * (input2.getN()>1? input1.getS() : 1.0) * input2.getN();
- else if (input1.isSparse() && !input2.isSparse())
- return input1.getCells() * input1.getS() * input2.getN();
- return input1.getCells() * input1.getS() * input2.getN() * input2.getS();
- } else if (inst instanceof CovarianceCPInstruction) { // cov
- // NOTE: output always scalar, input 3 used as weights block if(allExists)
- // same runtime for 2 and 3 inputs
- return 23 * input1.getM(); //(11+3*k+)
- } else if (inst instanceof QuantilePickCPInstruction) {
- // TODO: implement - previous implementation is missing
- throw new RuntimeException("Not implemented yet.");
+ double computeTime = varToCollect.rddStats.cost;
+ double collectTime;
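+		// illustrative decision (hypothetical sizes): a result fitting the free local memory
+		// budget is costed as a direct Spark collect(); e.g. a 10 GB result against 5 GB of
+		// free driver memory is instead costed as an HDFS write on the executors plus a read on the driver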
+ if (OptimizerUtils.checkSparkCollectMemoryBudget(varToCollect.characteristics, freeLocalMemory, false)) {
+ // use Spark collect()
+ collectTime = IOCostUtils.getSparkCollectTime(varToCollect.rddStats, driverMetrics, executorMetrics);
} else {
- // TODO: Make sure no other cases of BinaryCPInstruction exist than the mentioned below
- // NOTE: the case for BinaryScalarScalarCPInstruction,
- // BinaryMatrixScalarCPInstruction,
- // BinaryMatrixMatrixCPInstruction,
- // BinaryFrameMatrixCPInstruction,
- // BinaryFrameFrameCPInstruction,
- String opcode = inst.getOpcode();
- if( opcode.equals("+") || opcode.equals("-") //sparse safe
- && (input1.isSparse() || input2.isSparse()))
- return input1.getCellsWithSparsity() + input2.getCellsWithSparsity();
- else if( opcode.equals("solve") ) //see also MultiReturnBuiltin
- return input1.getCells() * input1.getN(); //for 1kx1k ~ 1GFLOP -> 0.5s
- else
- return output.getCells();
+ // redirect through HDFS (writing to HDFS on executors and reading back on driver)
+ varToCollect.fileInfo = new Object[] {HDFS_SOURCE_IDENTIFIER, FileFormat.BINARY};
+ collectTime = IOCostUtils.getHadoopWriteTime(varToCollect, executorMetrics) +
+ IOCostUtils.getFileSystemReadTime(varToCollect, driverMetrics);
}
- }
-
- private double getNFLOP_CPMatrixBuiltinNaryInst(MatrixBuiltinNaryCPInstruction inst, int numMatrices, VarStats output) throws CostEstimationException {
- String opcode = inst.getOpcode();
- switch (opcode) {
- case "nmin": case "nmax": case "n+": // for max, min plus num of cells for each matrix
- return numMatrices * output.getCells();
- case "rbind": case "cbind":
- return output.getCells();
- default:
- throw new DMLRuntimeException("Unknown opcode: "+opcode);
- }
- }
-
- private double getNFLOP_CPMultiReturnBuiltinInst(MultiReturnBuiltinCPInstruction inst, VarStats input) throws CostEstimationException {
- String opcode = inst.getOpcode();
- // NOTE: they all have cubic complexity, the scaling factor refers to commons.math
- double xf = 2; //default e.g, qr
- switch (opcode) {
- case "eigen":
- xf = 32;
- break;
- case "lu":
- xf = 16;
- break;
- case "svd":
- xf = 32; // TODO - assuming worst case for now
- break;
- }
- return xf * input.getCells() * input.getN(); //for 1kx1k ~ 2GFLOP -> 1s
- }
-
- private double getNFLOP_CPParameterizedBuiltinInst(ParameterizedBuiltinCPInstruction inst, VarStats input, VarStats output) throws CostEstimationException {
- String opcode = inst.getOpcode();
- if(opcode.equals("cdf") || opcode.equals("invcdf"))
- return DEFAULT_NFLOP_UNKNOWN; //scalar call to commons.math
- else if( opcode.equals("groupedagg") ){
- HashMap<String,String> paramsMap = inst.getParameterMap();
- String fn = paramsMap.get("fn");
- String order = paramsMap.get("order");
- CMOperator.AggregateOperationTypes type = CMOperator.getAggOpType(fn, order);
- int attr = type.ordinal();
- double xga = 1;
- switch(attr) {
- case 0: xga=4; break; //sum, see uk+
- case 1: xga=1; break; //count, see cm
- case 2: xga=8; break; //mean
- case 3: xga=16; break; //cm2
- case 4: xga=31; break; //cm3
- case 5: xga=51; break; //cm4
- case 6: xga=16; break; //variance
- }
- return 2 * input.getM() + xga * input.getM(); //scan for min/max, groupedagg
- }
- else if(opcode.equals("rmempty")){
- HashMap<String,String> paramsMap = inst.getParameterMap();
- int attr = paramsMap.get("margin").equals("rows")?0:1;
- switch(attr){
- case 0: //remove rows
- // TODO: Copied from old implementation but maybe reverse the cases?
- return ((input.isSparse()) ? input.getM() : input.getM() * Math.ceil(1.0d/input.getS())/2) +
- DEFAULT_NFLOP_CP * output.getCells();
- case 1: //remove cols
- return input.getN() * Math.ceil(1.0d/input.getS())/2 +
- DEFAULT_NFLOP_CP * output.getCells();
- default:
- throw new DMLRuntimeException("Invalid margin type for opcode "+opcode+".");
- }
-
+ if (varToCollect.rddStats.checkpoint) {
+ varToCollect.rddStats.cost = 0;
} else {
- System.out.println("Estimation for operation "+opcode+" is not supported yet.");
- return 1;
+ varToCollect.rddStats = null;
}
+
+		if (computeTime < 0 || collectTime < 0) {
+			// detection of functionality bugs
+			throw new RuntimeException("Unexpected negative value at estimating Spark Job execution time");
+		}
+		return computeTime + collectTime;
}
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Helpers for handling stats and estimating time related to their corresponding variables //
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
/**
- * Intended to be used to get the NFLOP for SPInstructions.
- * 'parse' because the cost of each instruction is to be
- * collected and the cost is to be computed at the end based on
- * all Spark instructions
- * @param inst
- * @return
+ * This method emulates the SystemDS mechanism of loading objects into
+ * the CP memory from a file or an existing RDD object.
+ *
+ * @param input variable for loading in CP memory
+ * @return estimated time in seconds for loading into memory
*/
- @SuppressWarnings("unused")
- protected double parseSPInst(SPInstruction inst) {
- // declare resource-dependant metrics
- double localCost = 0; // [nflop] cost for computing executed in executors
- double globalCost = 0; // [nflop] cost for computing executed in driver
- double IOCost = 0; // [s] cost for shuffling data and writing/reading to HDFS/S3
- // TODO: consider the case of matrix with dims=1
- // NOTE: consider if is needed to include the cost for final aggregation within the Spark Driver (CP)
- if (inst instanceof AggregateTernarySPInstruction) {
- // TODO: need to have a way to associate mVars from _stats with a
- // potentially existing virtual PairRDD - MatrixObject
- // NOTE: leave it for later once I figure out how to do it for unary instructions
- } else if (inst instanceof AggregateUnarySPInstruction) {
- AggregateUnarySPInstruction currentInst = (AggregateUnarySPInstruction) inst;
- if (currentInst.input1.isTensor())
- throw new DMLRuntimeException("CostEstimator does not support tensor input.");
- String opcode = currentInst.getOpcode();
- AggBinaryOp.SparkAggType aggType = currentInst.getAggType();
- AggregateUnaryOperator op = (AggregateUnaryOperator) currentInst.getOperator();
- VarStats input = _stats.get(currentInst.input1.getName());
- RDDStats inputRDD = input._rdd;
- RDDStats currentRDD = inputRDD;
- VarStats outputStats = _stats.get(currentInst.output.getName());
+ private double loadCPVarStatsAndEstimateTime(VarStats input) throws CostEstimationException {
+ if (input.isScalar() || input.allocatedMemory > 0) return 0.0;
- int k = getComputationFactorUAOp(opcode);
- // TODO: RRDstats extra required to keep at least the number of
- // blocks that each next operator operates on: e.g. filter (and mapByKey) is reducing this number,
- // probably better to create and store only intermediate RDDstats shared between instructions
- // since only these are needed for retrieving only intra instructions
- // TODO: later think of how to handle getting null for stats
- if (inputRDD == null) {
- throw new DMLRuntimeException("RDD stats should have been already initiated");
+ double loadTime;
+		// input.fileInfo != null: output of a reblock inst. -> execution not triggered
+		// input.rddStats.checkpoint: output of a checkpoint inst. -> execution not triggered
+ if (input.rddStats != null && (input.fileInfo == null || !input.rddStats.checkpoint)) {
+ // loading from RDD
+ loadTime = getTimeEstimateSparkJob(input);
+ } else {
+ // loading from a file
+ if (input.fileInfo == null || input.fileInfo.length != 2) {
+ throw new DMLRuntimeException("Time estimation is not possible without file info.");
+ } else if (!input.fileInfo[0].equals(HDFS_SOURCE_IDENTIFIER) && !input.fileInfo[0].equals(S3_SOURCE_IDENTIFIER)) {
+ throw new DMLRuntimeException("Time estimation is not possible for data source: " + input.fileInfo[0]);
}
- if (opcode.equals("uaktrace")) {
- // add cost for filter op
- localCost += currentRDD.numBlocks;
- currentRDD = RDDStats.transformNumBlocks(currentRDD, currentRDD.rlen); // only the diagonal blocks left
- }
- if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
- if (op.sparseSafe) {
- localCost += currentRDD.numBlocks; // filter cost
- // TODO: decide how to reduce numBlocks
- }
- localCost += k*currentRDD.numValues*currentRDD.sparsity; // map cost
- // next op is fold -> end of the current Job
- // end of Job -> no need to assign the currentRDD to the output (output is no RDD)
- localCost += currentRDD.numBlocks; // local folding cost
- // TODO: shuffle cost to bring all pairs to the driver (CP)
- // NOTE: neglect the added CP compute cost for folding the distributed aggregates
- } else if (aggType == AggBinaryOp.SparkAggType.MULTI_BLOCK){
- localCost += k*currentRDD.numValues*currentRDD.sparsity; // mapToPair cost
- // NOTE: the new unique number of keys should be
- // next op is combineByKey -> new stage
- localCost += currentRDD.numBlocks + currentRDD.numPartitions; // local merging * merging partitions
- if (op.aggOp.existsCorrection())
- localCost += currentRDD.numBlocks; // mapValues cost for the correction
- } else { // aggType == AggBinaryOp.SparkAggType.NONE
- localCost += k*currentRDD.numValues*currentRDD.sparsity;
- // no reshuffling -> inst is packed with the next spark operation
- }
-
- return globalCost;
- } else if (inst instanceof RandSPInstruction) {
- RandSPInstruction randInst = (RandSPInstruction) inst;
- String opcode = randInst.getOpcode();
- VarStats output = _stats.get(randInst.output.getName());
- // NOTE: update sparsity here
- output._mc.setNonZeros((long) (output.getCells()*randInst.getSparsity()));
- RDDStats outputRDD = new RDDStats(output);
-
- int complexityFactor = 1;
- switch (opcode.toLowerCase()) {
- case DataGen.RAND_OPCODE:
- complexityFactor = 32; // higher complexity for random number generation
- case DataGen.SEQ_OPCODE:
- // first op. from the new stage: parallelize/read from scratch file
- globalCost += complexityFactor*outputRDD.numBlocks; // cp random number generation
- if (outputRDD.numBlocks < RandSPInstruction.INMEMORY_NUMBLOCKS_THRESHOLD) {
- long parBlockSize = MatrixBlock.estimateSizeDenseInMemory(outputRDD.numBlocks, 1);
- IOCost += IOCostUtils.getSparkTransmissionCost(parBlockSize, outputRDD.numParallelTasks);
- } else {
- IOCost += IOCostUtils.getWriteTime(outputRDD.numBlocks, 1, 1.0, HDFS_SOURCE_IDENTIFIER, TEXT); // driver writes down
- IOCost += IOCostUtils.getReadTime(outputRDD.numBlocks, 1, 1.0, HDFS_SOURCE_IDENTIFIER, TEXT); // executors read
- localCost += outputRDD.numBlocks; // mapToPair cost
- }
- localCost += complexityFactor*outputRDD.numValues; // mapToPair cost
- output._rdd = outputRDD;
- return globalCost;
- case DataGen.SAMPLE_OPCODE:
- // first op. from the new stage: parallelize
- complexityFactor = 32; // TODO: set realistic factor
- globalCost += complexityFactor*outputRDD.numPartitions; // cp random number generation
- long parBlockSize = MatrixBlock.estimateSizeDenseInMemory(outputRDD.numPartitions, 1);
- IOCost += IOCostUtils.getSparkTransmissionCost(parBlockSize, outputRDD.numParallelTasks);
- localCost += outputRDD.numBlocks; // flatMap cost
- localCost += complexityFactor*outputRDD.numBlocks; // mapToPairCost cost
- // sortByKey -> new stage
- long randBlockSize = MatrixBlock.estimateSizeDenseInMemory(outputRDD.numBlocks, 1);
- IOCost += IOCostUtils.getShuffleCost(randBlockSize, outputRDD.numParallelTasks);
- localCost += outputRDD.numValues;
- // sortByKey -> shuffling?
- case DataGen.FRAME_OPCODE:
- }
- return globalCost;
+ loadTime = IOCostUtils.getFileSystemReadTime(input, driverMetrics);
}
-
- throw new DMLRuntimeException("Unsupported instruction: " + inst.getOpcode());
- }
-
- /**
- * Intended to handle RDDStats retrievals so the I/O
- * can be computed correctly. This method should also
- * reserve the necessary memory for the RDD on the executors.
- * @param var
- * @param outputRDD
- * @return
- */
- @SuppressWarnings("unused")
- private double getRDDHandleAndEstimateTime(VarStats var, RDDStats outputRDD) {
- double ret = 0;
- if (var._rdd == null) {
- RDDStats newRDD = new RDDStats(var);
- if (var._memory >= 0) { // dirty or cached
- if (!_parRDDs.reserve(newRDD.totalSize)) {
- if (var._dirty) {
- ret += IOCostUtils.getWriteTime(var.getM(), var.getN(), var.getS(), HDFS_SOURCE_IDENTIFIER, BINARY);
- // TODO: think when to set it to true
- var._dirty = false;
- }
- ret += IOCostUtils.getReadTime(var.getM(), var.getN(), var.getS(), HDFS_SOURCE_IDENTIFIER, BINARY) / newRDD.numParallelTasks;
- } else {
- ret += IOCostUtils.getSparkTransmissionCost(newRDD.totalSize, newRDD.numParallelTasks);
- }
- } else { // on hdfs
- if (var._fileInfo == null || var._fileInfo.length != 2)
- throw new DMLRuntimeException("File info missing for a file to be read on Spark.");
- ret += IOCostUtils.getReadTime(var.getM(), var.getN(), var.getS(), (String)var._fileInfo[0], (Types.FileFormat) var._fileInfo[1]) / newRDD.numParallelTasks;
- var._dirty = false; // possibly redundant
- }
- var._rdd = newRDD;
- }
- // if RDD handle exists -> no additional cost to add
- outputRDD = var._rdd;
- return ret;
- }
-
- /////////////////////
- // I/O Costs //
- /////////////////////
-
- private double getLoadTime(VarStats input) throws CostEstimationException {
- if (input == null || input._memory > 0) return 0.0; // input == null marks scalars
- // loading from RDD
- if (input._rdd != null) {
- if (OptimizerUtils.checkSparkCollectMemoryBudget(input._mc, usedMememory, false)) { // .collect()
- long sizeEstimate = OptimizerUtils.estimatePartitionedSizeExactSparsity(input._mc);
- putInMemory(input);
- return IOCostUtils.getSparkTransmissionCost(sizeEstimate, input._rdd.numParallelTasks);
- } else { // redirect through HDFS
- putInMemory(input);
- return IOCostUtils.getWriteTime(input.getM(), input.getN(), input.getS(), HDFS_SOURCE_IDENTIFIER, null) / input._rdd.numParallelTasks +
- IOCostUtils.getReadTime(input.getM(), input.getN(), input.getS(), HDFS_SOURCE_IDENTIFIER, null); // cost for writting to HDFS on executors and reading back on driver
- }
- }
- // loading from a file
- if (input._fileInfo == null || input._fileInfo.length != 2) {
- return 1;
- }
- else if (!input._fileInfo[0].equals(HDFS_SOURCE_IDENTIFIER) && !input._fileInfo[0].equals(S3_SOURCE_IDENTIFIER)) {
- throw new DMLRuntimeException("Time estimation is not possible for data source: "+ input._fileInfo[0]);
- }
+ input.allocatedMemory = OptimizerUtils.estimateSizeExactSparsity(input.characteristics);
putInMemory(input);
- return IOCostUtils.getReadTime(input.getM(), input.getN(), input.getS(), (String) input._fileInfo[0], (Types.FileFormat) input._fileInfo[1]);
+ return loadTime;
}
- private void putInMemory(VarStats input) throws CostEstimationException {
- if(input == null)
- return;
- long sizeEstimate = OptimizerUtils.estimateSize(input._mc);
- if (sizeEstimate + usedMememory > localMemory)
+ private void putInMemory(VarStats output) throws CostEstimationException {
+ if (output.isScalar() || output.allocatedMemory <= MIN_MEMORY_TO_TRACK) return;
+ if (freeLocalMemory - output.allocatedMemory < 0)
throw new CostEstimationException("Insufficient local memory");
- usedMememory += sizeEstimate;
- input._memory = sizeEstimate;
+ freeLocalMemory -= output.allocatedMemory;
}
private void removeFromMemory(VarStats input) {
- if (input == null) return; // for scalars
- usedMememory -= input._memory;
- input._memory = -1;
- }
- /////////////////////
- // HELPERS //
- /////////////////////
-
- private static int getComputationFactorUAOp(String opcode) {
- switch (opcode) {
- case "uatrace": case "uaktrace":
- return 2;
- case "uak+": case "uark+": case "uack+":
- return 4; // 1*k+
- case "uasqk+": case "uarsqk+": case "uacsqk+":
- return 5; // +1 for multiplication to square term
- case "uamean": case "uarmean": case "uacmean":
- return 7; // 1*k+
- case "uavar": case "uarvar": case "uacvar":
- return 14;
- default:
- return 1;
+ if (input == null) return; // scalars or variables never put in memory
+ if (!input.isScalar() && input.allocatedMemory > MIN_MEMORY_TO_TRACK) {
+ freeLocalMemory += input.allocatedMemory;
+ if (freeLocalMemory > localMemoryLimit) {
+ // detection of functionality bugs
+ throw new RuntimeException("Unexpectedly large amount of freed CP memory");
+ }
}
+ if (input.rddStats != null) {
+ input.rddStats = null;
+ }
+ input.allocatedMemory = -1;
+ }
+
+ /**
+	 * This method plays a central role in the mechanism for
+	 * estimating the execution time of Spark instructions:
+	 * it estimates the time for distributing an existing CP variable,
+	 * or it returns the time already estimated for computing the
+	 * input variable on Spark.
+ * @param input input statistics
+ * @return time (seconds) for loading the corresponding variable
+ */
+ private double loadRDDStatsAndEstimateTime(VarStats input) {
+ if (input.isScalar()) return 0.0;
+
+ double ret;
+ if (input.rddStats == null) { // rdd is to be distributed by the CP
+ input.rddStats = new RDDStats(input);
+ RDDStats inputRDD = input.rddStats;
+ if (input.allocatedMemory >= 0) { // generated object locally
+ if (inputRDD.distributedSize < freeLocalMemory && inputRDD.distributedSize < (0.1 * localMemoryLimit)) {
+				// in this case transfer the data object over HDFS (first set the fileInfo of the input)
+ input.fileInfo = new Object[] {HDFS_SOURCE_IDENTIFIER, FileFormat.BINARY};
+ ret = IOCostUtils.getFileSystemWriteTime(input, driverMetrics);
+ ret += IOCostUtils.getHadoopReadTime(input, executorMetrics);
+ } else {
+ ret = IOCostUtils.getSparkParallelizeTime(inputRDD, driverMetrics, executorMetrics);
+ }
+ } else { // on hdfs
+ if (input.fileInfo == null || input.fileInfo.length != 2)
+ throw new RuntimeException("File info missing for a file to be read on Spark.");
+ ret = IOCostUtils.getHadoopReadTime(input, executorMetrics);
+ }
+ } else if (input.rddStats.distributedSize > 0) {
+			// if the input RDD size is initialized -> the cost should already be calculated
+			// transfer the cost to the output RDD for proper lineage handling
+ ret = input.rddStats.cost;
+ if (input.rddStats.checkpoint) {
+ // cost of checkpoint var transferred only once
+ input.rddStats.cost = 0;
+ }
+ } else {
+ throw new RuntimeException("Initialized RDD stats without initialized data characteristics is undefined behaviour");
+ }
+ return ret;
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Generic non-static helpers/utility methods //
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ private VarStats getCTableDim(CPOperand dimOperand) {
+ VarStats dimStats;
+ if (dimOperand.isLiteral()) {
+ dimStats = new VarStats(dimOperand.getLiteral().toString(), null);
+ } else {
+ dimStats = getStatsWithDefaultScalar(dimOperand.getName());
+ }
+ return dimStats;
+ }
+
+ private VarStats getParameterizedBuiltinParamStats(String key, HashMap<String, String> params, boolean required) {
+ String varName = params.get(key);
+ if (required && varName == null) {
+ throw new RuntimeException("ParameterizedBuiltin operation is missing required parameter object for key " + key);
+ } else if (varName == null) {
+ return null;
+ }
+ return getStatsWithDefaultScalar(varName);
}
}
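
The driver-side memory accounting above (putInMemory/removeFromMemory against freeLocalMemory) can be summarized with a small standalone sketch. This is an illustrative snippet and not part of the patch; the class name, memory limit, and tracking threshold are assumed values chosen only to show the bookkeeping rule: fail the estimation when an object does not fit, and sanity-check the balance when it is released.

	// Minimal sketch of the estimator's CP memory bookkeeping (illustrative only, assumed values)
	public class MemoryBookkeepingSketch {
		static final long MIN_MEMORY_TO_TRACK = 1024 * 1024;            // assumed 1MB tracking threshold
		static final long LOCAL_MEMORY_LIMIT = 4L * 1024 * 1024 * 1024; // assumed 4GB driver budget
		static long freeLocalMemory = LOCAL_MEMORY_LIMIT;

		// mirrors putInMemory(): abort the estimation if the object does not fit
		static void put(long allocatedMemory) {
			if (allocatedMemory <= MIN_MEMORY_TO_TRACK) return; // small objects are not tracked
			if (freeLocalMemory - allocatedMemory < 0)
				throw new RuntimeException("Insufficient local memory");
			freeLocalMemory -= allocatedMemory;
		}

		// mirrors removeFromMemory(): release the budget and detect bookkeeping bugs
		static void remove(long allocatedMemory) {
			if (allocatedMemory <= MIN_MEMORY_TO_TRACK) return;
			freeLocalMemory += allocatedMemory;
			if (freeLocalMemory > LOCAL_MEMORY_LIMIT)
				throw new RuntimeException("Unexpectedly large amount of freed CP memory");
		}

		public static void main(String[] args) {
			put(512L * 1024 * 1024);    // track a 512MB intermediate
			remove(512L * 1024 * 1024); // free it again
			System.out.println("free driver memory: " + freeLocalMemory + " bytes");
		}
	}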
diff --git a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
index 81913dd..a04030c 100644
--- a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
@@ -20,187 +20,528 @@
package org.apache.sysds.resource.cost;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.resource.CloudInstance;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
public class IOCostUtils {
- // NOTE: this class does NOT include methods for estimating IO time
- // for operation ot the local file system since they are not relevant at the moment
+
+ private static final double READ_DENSE_FACTOR = 0.5;
+ private static final double WRITE_DENSE_FACTOR = 0.3;
+ private static final double SPARSE_FACTOR = 0.5;
+ private static final double TEXT_FACTOR = 0.3;
+ // NOTE: skip using such factors for now
+ // private static final double WRITE_MEMORY_FACTOR = 0.9;
+ // private static final double WRITE_DISK_FACTOR = 0.5;
+ private static final double SERIALIZATION_FACTOR = 0.5;
+ private static final double DESERIALIZATION_FACTOR = 0.8;
+ public static final long DEFAULT_FLOPS = 2L * 1024 * 1024 * 1024; // 2 gFLOPS
+
+ public static class IOMetrics {
+		// FLOPS value not directly related to I/O metrics,
+		// but it is not worth storing it separately
+ long cpuFLOPS;
+ int cpuCores;
+ // All metrics here use MB/s bandwidth unit
+ // Metrics for disk I/O operations
+ double localDiskReadBandwidth;
+ double localDiskWriteBandwidth;
+ double hdfsReadBinaryDenseBandwidth;
+ double hdfsReadBinarySparseBandwidth;
+ double hdfsWriteBinaryDenseBandwidth;
+ double hdfsWriteBinarySparseBandwidth;
+ double hdfsReadTextDenseBandwidth;
+ double hdfsReadTextSparseBandwidth;
+ double hdfsWriteTextDenseBandwidth;
+ double hdfsWriteTextSparseBandwidth;
+		// no S3 binary read/write metrics since S3 is not used for storing intermediate results
+ double s3ReadTextDenseBandwidth;
+ double s3ReadTextSparseBandwidth;
+ double s3WriteTextDenseBandwidth;
+ double s3WriteTextSparseBandwidth;
+ // Metrics for main memory I/O operations
+ double memReadBandwidth;
+ double memWriteBandwidth;
+ // Metrics for networking operations
+ double networkingBandwidth;
+ // Metrics for (de)serialization
+ double serializationBandwidth;
+ double deserializationBandwidth;
+
+ public IOMetrics(CloudInstance instance) {
+ this(instance.getFLOPS(), instance.getVCPUs(), instance.getMemorySpeed(), instance.getDiskSpeed(), instance.getNetworkSpeed());
+ }
+ public IOMetrics(long flops, int cores, double memorySpeed, double diskSpeed, double networkSpeed) {
+ cpuFLOPS = flops;
+ cpuCores = cores;
+ // Metrics for disk I/O operations
+ localDiskReadBandwidth = diskSpeed;
+ localDiskWriteBandwidth = diskSpeed;
+			// Assume that HDFS I/O operations are always done by accessing local blocks
+ hdfsReadBinaryDenseBandwidth = diskSpeed * READ_DENSE_FACTOR;
+ hdfsReadBinarySparseBandwidth = hdfsReadBinaryDenseBandwidth * SPARSE_FACTOR;
+ hdfsWriteBinaryDenseBandwidth = diskSpeed * WRITE_DENSE_FACTOR;
+ hdfsWriteBinarySparseBandwidth = hdfsWriteBinaryDenseBandwidth * SPARSE_FACTOR;
+ hdfsReadTextDenseBandwidth = hdfsReadBinaryDenseBandwidth * TEXT_FACTOR;
+ hdfsReadTextSparseBandwidth = hdfsReadBinarySparseBandwidth * TEXT_FACTOR;
+ hdfsWriteTextDenseBandwidth = hdfsWriteBinaryDenseBandwidth * TEXT_FACTOR;
+ hdfsWriteTextSparseBandwidth = hdfsWriteBinarySparseBandwidth * TEXT_FACTOR;
+			// Metrics for networking operations (set before deriving the S3 bandwidths from it)
+			networkingBandwidth = networkSpeed;
+			s3ReadTextDenseBandwidth = networkingBandwidth * READ_DENSE_FACTOR * TEXT_FACTOR;
+			s3ReadTextSparseBandwidth = s3ReadTextDenseBandwidth * SPARSE_FACTOR;
+			s3WriteTextDenseBandwidth = networkingBandwidth * WRITE_DENSE_FACTOR * TEXT_FACTOR;
+			s3WriteTextSparseBandwidth = s3WriteTextDenseBandwidth * SPARSE_FACTOR;
+			// Metrics for main memory I/O operations
+			memReadBandwidth = memorySpeed;
+			memWriteBandwidth = memorySpeed;
+ // Metrics for (de)serialization,
+ double currentFlopsFactor = (double) DEFAULT_FLOPS / cpuFLOPS;
+ serializationBandwidth = memReadBandwidth * SERIALIZATION_FACTOR * currentFlopsFactor;
+ deserializationBandwidth = memWriteBandwidth * DESERIALIZATION_FACTOR * currentFlopsFactor;
+ }
+
+ // ----- Testing default -----
+ public static final int DEFAULT_NUM_CPU_CORES = 8;
+ //IO Read
+ public static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 21328.0; // e.g. DDR4-2666
+ public static final double DEFAULT_MBS_DISK_BANDWIDTH = 600; // e.g. m5.4xlarge, baseline bandwidth: 4750Mbps = 593.75 MB/s
+		public static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 640; // e.g. m5.4xlarge, baseline network bandwidth: 5Gbps = 640MB/s
+ public static final double DEFAULT_MBS_HDFS_READ_BINARY_DENSE = 150;
+ public static final double DEFAULT_MBS_HDFS_READ_BINARY_SPARSE = 75;
+ public static final double DEFAULT_MBS_S3_READ_TEXT_DENSE = 50;
+ public static final double DEFAULT_MBS_S3_READ_TEXT_SPARSE = 25;
+ public static final double DEFAULT_MBS_HDFS_READ_TEXT_DENSE = 75;
+ public static final double DEFAULT_MBS_HDFS_READ_TEXT_SPARSE = 50;
+ // IO Write
+ public static final double DEFAULT_MBS_HDFS_WRITE_BINARY_DENSE = 120;
+ public static final double DEFAULT_MBS_HDFS_WRITE_BINARY_SPARSE = 60;
+ public static final double DEFAULT_MBS_S3_WRITE_TEXT_DENSE = 30;
+ public static final double DEFAULT_MBS_S3_WRITE_TEXT_SPARSE = 20;
+ public static final double DEFAULT_MBS_HDFS_WRITE_TEXT_DENSE = 40;
+ public static final double DEFAULT_MBS_HDFS_WRITE_TEXT_SPARSE = 30;
+
+ /**
+ * Meant to be used for testing by setting known
+ * default values for each metric
+ */
+ public IOMetrics() {
+ cpuFLOPS = DEFAULT_FLOPS;
+ cpuCores = DEFAULT_NUM_CPU_CORES;
+ // Metrics for disk I/O operations
+ localDiskReadBandwidth = DEFAULT_MBS_DISK_BANDWIDTH;
+ localDiskWriteBandwidth = DEFAULT_MBS_DISK_BANDWIDTH;
+			// Assume that HDFS I/O operations are always done by accessing local blocks
+ hdfsReadBinaryDenseBandwidth = DEFAULT_MBS_HDFS_READ_BINARY_DENSE;
+ hdfsReadBinarySparseBandwidth = DEFAULT_MBS_HDFS_READ_BINARY_SPARSE;
+ hdfsWriteBinaryDenseBandwidth = DEFAULT_MBS_HDFS_WRITE_BINARY_DENSE;
+ hdfsWriteBinarySparseBandwidth = DEFAULT_MBS_HDFS_WRITE_BINARY_SPARSE;
+ hdfsReadTextDenseBandwidth = DEFAULT_MBS_HDFS_READ_TEXT_DENSE;
+ hdfsReadTextSparseBandwidth = DEFAULT_MBS_HDFS_READ_TEXT_SPARSE;
+ hdfsWriteTextDenseBandwidth = DEFAULT_MBS_HDFS_WRITE_TEXT_DENSE;
+ hdfsWriteTextSparseBandwidth = DEFAULT_MBS_HDFS_WRITE_TEXT_SPARSE;
+ s3ReadTextDenseBandwidth = DEFAULT_MBS_S3_READ_TEXT_DENSE;
+ s3ReadTextSparseBandwidth = DEFAULT_MBS_S3_READ_TEXT_SPARSE;
+ s3WriteTextDenseBandwidth = DEFAULT_MBS_S3_WRITE_TEXT_DENSE;
+ s3WriteTextSparseBandwidth = DEFAULT_MBS_S3_WRITE_TEXT_SPARSE;
+ // Metrics for main memory I/O operations
+ memReadBandwidth = DEFAULT_MBS_MEMORY_BANDWIDTH;
+ memWriteBandwidth = DEFAULT_MBS_MEMORY_BANDWIDTH;
+ // Metrics for networking operations
+ networkingBandwidth = DEFAULT_MBS_NETWORK_BANDWIDTH;
+ // Metrics for (de)serialization,
+ double currentFlopsFactor = (double) DEFAULT_FLOPS / cpuFLOPS;
+ serializationBandwidth = memReadBandwidth * SERIALIZATION_FACTOR * currentFlopsFactor;
+ deserializationBandwidth = memWriteBandwidth * DESERIALIZATION_FACTOR * currentFlopsFactor;
+ }
+ }
+
protected static final String S3_SOURCE_IDENTIFIER = "s3";
protected static final String HDFS_SOURCE_IDENTIFIER = "hdfs";
- //IO READ throughput
- private static final double DEFAULT_MBS_S3READ_BINARYBLOCK_DENSE = 200;
- private static final double DEFAULT_MBS_S3READ_BINARYBLOCK_SPARSE = 100;
- private static final double DEFAULT_MBS_HDFSREAD_BINARYBLOCK_DENSE = 150;
- public static final double DEFAULT_MBS_HDFSREAD_BINARYBLOCK_SPARSE = 75;
- private static final double DEFAULT_MBS_S3READ_TEXT_DENSE = 50;
- private static final double DEFAULT_MBS_S3READ_TEXT_SPARSE = 25;
- private static final double DEFAULT_MBS_HDFSREAD_TEXT_DENSE = 75;
- private static final double DEFAULT_MBS_HDFSREAD_TEXT_SPARSE = 50;
- //IO WRITE throughput
- private static final double DEFAULT_MBS_S3WRITE_BINARYBLOCK_DENSE = 150;
- private static final double DEFAULT_MBS_S3WRITE_BINARYBLOCK_SPARSE = 75;
- private static final double DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_DENSE = 120;
- private static final double DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_SPARSE = 60;
- private static final double DEFAULT_MBS_S3WRITE_TEXT_DENSE = 30;
- private static final double DEFAULT_MBS_S3WRITE_TEXT_SPARSE = 20;
- private static final double DEFAULT_MBS_HDFSWRITE_TEXT_DENSE = 40;
- private static final double DEFAULT_MBS_HDFSWRITE_TEXT_SPARSE = 30;
- // New -> Spark cost estimation
- private static final double DEFAULT_NETWORK_BANDWIDTH = 100; // bandwidth for shuffling data
- //private static final double DEFAULT_DISK_BANDWIDTH = 1000; // bandwidth for shuffling data
- private static final double DEFAULT_NETWORK_LATENCY = 0.001; // latency for data transfer in seconds
- //private static final double DEFAULT_META_TO_DRIVER_MS = 10; // cost in ms to account for the metadata transmitted to the driver at the end of each stage
- private static final double SERIALIZATION_FACTOR = 10; // virtual unit - MB/(GFLOPS*s)
- private static final double MIN_TRANSFER_TIME = 0.001; // 1ms
- private static final double MIN_SERIALIZATION_TIME = 0.001; // 1ms (intended to include serialization and deserialization time)
- private static final double DEFAULT_MBS_MEM_READ_BANDWIDTH = 32000; // TODO: dynamic value later
- private static final double DEFAULT_MBS_MEM_WRITE_BANDWIDTH = 32000; // TODO: dynamic value later
- protected static double getMemReadTime(VarStats stats) {
- if (stats == null) return 0; // scalars
- if (stats._memory < 0)
- return 1;
- long size = stats._memory;
- double sizeMB = (double) size / (1024 * 1024);
-
- return sizeMB / DEFAULT_MBS_MEM_READ_BANDWIDTH;
- }
-
- protected static double getMemWriteTime(VarStats stats) {
- if (stats == null) return 0; // scalars
- if (stats._memory < 0)
- throw new DMLRuntimeException("VarStats should have estimated size before getting write time");
- long size = stats._memory;
- double sizeMB = (double) size / (1024 * 1024);
-
- return sizeMB / DEFAULT_MBS_MEM_WRITE_BANDWIDTH;
- }
/**
- * Returns the estimated read time from HDFS.
- * NOTE: Does not handle unknowns.
+ * Estimate time to scan object in memory in CP.
*
- * @param dm rows?
- * @param dn columns?
- * @param ds sparsity factor?
- * @param source data source (S3 or HDFS)
- * @param format file format (null for binary)
- * @return estimated HDFS read time
+ * @param stats object statistics
+ * @param metrics CP node's metrics
+ * @return estimated time in seconds
*/
- protected static double getReadTime(long dm, long dn, double ds, String source, Types.FileFormat format)
+ public static double getMemReadTime(VarStats stats, IOMetrics metrics) {
+ if (stats.isScalar()) return 0; // scalars
+ if (stats.allocatedMemory < 0)
+ throw new RuntimeException("VarStats.allocatedMemory should carry the estimated size before getting read time");
+ double sizeMB = (double) stats.allocatedMemory / (1024 * 1024);
+ return sizeMB / metrics.memReadBandwidth;
+ }
+
+ /**
+ * Estimate time to scan distributed data sets in memory on Spark.
+ * It integrates a mechanism to account for scanning
+ * spilled-over data sets on the local disk.
+ *
+ * @param stats object statistics
+ * @param metrics CP node's metrics
+ * @return estimated time in seconds
+ */
+ public static double getMemReadTime(RDDStats stats, IOMetrics metrics) {
+ // no scalars expected
+ double size = (double) stats.distributedSize;
+ if (size < 0)
+			throw new RuntimeException("RDDStats.distributedSize should carry the estimated size before getting read time");
+ // define if/what a fraction is spilled over to disk
+ double minExecutionMemory = SparkExecutionContext.getDataMemoryBudget(true, false); // execution mem = storage mem
+ double spillOverFraction = minExecutionMemory >= size? 0 : (size - minExecutionMemory) / size;
+		// for simplification define an average read bandwidth combining memory and disk bandwidths
+ double mixedBandwidthPerCore = (spillOverFraction * metrics.localDiskReadBandwidth +
+ (1-spillOverFraction) * metrics.memReadBandwidth) / metrics.cpuCores;
+ double numWaves = Math.ceil((double) stats.numPartitions / SparkExecutionContext.getDefaultParallelism(false));
+ double sizeMB = size / (1024 * 1024);
+ double partitionSizeMB = sizeMB / stats.numPartitions;
+ return numWaves * (partitionSizeMB / mixedBandwidthPerCore);
+ }
+
+ /**
+ * Estimate time to write object to memory in CP.
+ *
+ * @param stats object statistics
+ * @param metrics CP node's metrics
+ * @return estimated time in seconds
+ */
+ public static double getMemWriteTime(VarStats stats, IOMetrics metrics) {
+ if (stats == null) return 0; // scalars
+ if (stats.allocatedMemory < 0)
+ throw new DMLRuntimeException("VarStats.allocatedMemory should carry the estimated size before getting write time");
+ double sizeMB = (double) stats.allocatedMemory / (1024 * 1024);
+
+ return sizeMB / metrics.memWriteBandwidth;
+ }
+
+ /**
+	 * Estimate time to write a distributed data set to memory on Spark.
+	 * It does NOT integrate a mechanism to account for spill-overs.
+ *
+ * @param stats object statistics
+ * @param metrics CP node's metrics
+ * @return estimated time in seconds
+ */
+ public static double getMemWriteTime(RDDStats stats, IOMetrics metrics) {
+ // no scalars expected
+ if (stats.distributedSize < 0)
+			throw new RuntimeException("RDDStats.distributedSize should carry the estimated size before getting write time");
+ double numWaves = Math.ceil((double) stats.numPartitions / SparkExecutionContext.getDefaultParallelism(false));
+ double sizeMB = (double) stats.distributedSize / (1024 * 1024);
+ double partitionSizeMB = sizeMB / stats.numPartitions;
+ return numWaves * partitionSizeMB / (metrics.memWriteBandwidth / metrics.cpuCores);
+ }
+
+ /**
+ * Estimates the read time for a file on HDFS or S3 by the Control Program
+ * @param stats stats for the input matrix/object
+ * @param metrics I/O metrics for the driver node
+ * @return estimated time in seconds
+ */
+ public static double getFileSystemReadTime(VarStats stats, IOMetrics metrics) {
+ String sourceType = (String) stats.fileInfo[0];
+ Types.FileFormat format = (Types.FileFormat) stats.fileInfo[1];
+ double sizeMB = getFileSizeInMB(stats);
+ boolean isSparse = MatrixBlock.evalSparseFormatOnDisk(stats.getM(), stats.getN(), stats.getNNZ());
+ return getStorageReadTime(sizeMB, isSparse, sourceType, format, metrics);
+ }
+
+ /**
+	 * Estimates the read time for a file on HDFS or S3 by the Spark cluster.
+	 * Rather than calculating the execution time directly from the object size,
+	 * it assumes full executor utilization and the maximum block size read by
+	 * each executor core (the HDFS block size). The estimated time for one such
+	 * "fully utilized" read is then multiplied by the number of execution waves,
+	 * since even a not fully utilized last wave should take approximately the same
+	 * time as if all slots were assigned an active reading task.
+ * This function cannot rely on the {@code RDDStats} since they would not be
+ * initialized for the input object.
+ * @param stats stats for the input matrix/object
+ * @param metrics I/O metrics for the executor node
+ * @return estimated time in seconds
+ */
+ public static double getHadoopReadTime(VarStats stats, IOMetrics metrics) {
+ String sourceType = (String) stats.fileInfo[0];
+ Types.FileFormat format = (Types.FileFormat) stats.fileInfo[1];
+ long size = getPartitionedFileSize(stats);
+		// since getStorageReadTime() computes the read time utilizing the whole executor resources
+ // use the fact that <partition size> / <bandwidth per slot> = <partition size> * <slots per executor> / <bandwidth per executor>
+ long hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
+ double numPartitions = Math.ceil((double) size / hdfsBlockSize);
+ double sizePerExecutorMB = (double) (metrics.cpuCores * hdfsBlockSize) / (1024*1024);
+ boolean isSparse = MatrixBlock.evalSparseFormatOnDisk(stats.getM(), stats.getN(), stats.getNNZ());
+ double timePerCore = getStorageReadTime(sizePerExecutorMB, isSparse, sourceType, format, metrics); // same as time per executor
+ // number of execution waves (maximum task to execute per core)
+ double numWaves = Math.ceil(numPartitions / (SparkExecutionContext.getNumExecutors() * metrics.cpuCores));
+ return numWaves * timePerCore;
+ }
+
+ private static double getStorageReadTime(double sizeMB, boolean isSparse, String source, Types.FileFormat format, IOMetrics metrics)
{
- boolean sparse = MatrixBlock.evalSparseFormatOnDisk(dm, dn, (long)(ds*dm*dn));
- double ret = ((double)MatrixBlock.estimateSizeOnDisk(dm, dn, (long)(ds*dm*dn))) / (1024*1024);
-
- if (format == null || !format.isTextFormat()) {
+ double time;
+ // TODO: consider if the text or binary should be default if format == null
+ if (format == null || format.isTextFormat()) {
if (source.equals(S3_SOURCE_IDENTIFIER)) {
- if (sparse)
- ret /= DEFAULT_MBS_S3READ_BINARYBLOCK_SPARSE;
+ if (isSparse)
+ time = sizeMB / metrics.s3ReadTextSparseBandwidth;
+ else // dense
+ time = sizeMB / metrics.s3ReadTextDenseBandwidth;
+ } else { // HDFS
+ if (isSparse)
+ time = sizeMB / metrics.hdfsReadTextSparseBandwidth;
else //dense
- ret /= DEFAULT_MBS_S3READ_BINARYBLOCK_DENSE;
- } else { //HDFS
- if (sparse)
- ret /= DEFAULT_MBS_HDFSREAD_BINARYBLOCK_SPARSE;
- else //dense
- ret /= DEFAULT_MBS_HDFSREAD_BINARYBLOCK_DENSE;
+ time = sizeMB / metrics.hdfsReadTextDenseBandwidth;
}
- } else {
+ } else if (format == Types.FileFormat.BINARY) {
+ if (source.equals(HDFS_SOURCE_IDENTIFIER)) {
+ if (isSparse)
+ time = sizeMB / metrics.hdfsReadBinarySparseBandwidth;
+ else //dense
+ time = sizeMB / metrics.hdfsReadBinaryDenseBandwidth;
+ } else { // S3
+ throw new RuntimeException("Reading binary files from S3 is not supported");
+ }
+ } else { // compressed
+ throw new RuntimeException("Format " + format + " is not supported for estimation yet.");
+ }
+ return time;
+ }
+
+ /**
+ * Estimates the time for writing a file to HDFS or S3.
+ *
+ * @param stats stats for the input matrix/object
+ * @param metrics I/O metrics for the driver node
+ * @return estimated time in seconds
+ */
+ public static double getFileSystemWriteTime(VarStats stats, IOMetrics metrics) {
+ String sourceType = (String) stats.fileInfo[0];
+ Types.FileFormat format = (Types.FileFormat) stats.fileInfo[1];
+ double sizeMB = getFileSizeInMB(stats);
+ boolean isSparse = MatrixBlock.evalSparseFormatOnDisk(stats.getM(), stats.getN(), stats.getNNZ());
+ return getStorageWriteTime(sizeMB, isSparse, sourceType, format, metrics);
+ }
+
+ /**
+ * Estimates the write time for a file on HDFS or S3 by Spark cluster.
+	 * Follows the same logic as {@code getHadoopReadTime}, but here
+	 * it can rely on the {@code RDDStats} since the input object
+	 * should have been initialized by the prior instruction.
+ * @param stats stats for the input matrix/object
+ * @param metrics I/O metrics for the executor node
+ * @return estimated time in seconds
+ */
+ public static double getHadoopWriteTime(VarStats stats, IOMetrics metrics) {
+ if (stats.rddStats == null) {
+			throw new RuntimeException("Estimating the Hadoop write time requires a VarStats object with an assigned 'rddStats' member");
+ }
+ String sourceType = (String) stats.fileInfo[0];
+ Types.FileFormat format = (Types.FileFormat) stats.fileInfo[1];
+ long size = getPartitionedFileSize(stats);
+ // time = <num. waves> * <partition size> / <bandwidth per slot>
+ // here it cannot be assumed that the partition size is equal to the HDFS block size
+ double sizePerPartitionMB = (double) size / stats.rddStats.numPartitions / (1024*1024);
+		// since getStorageWriteTime() computes the write time utilizing the whole executor resources
+ // use the fact that <partition size> / <bandwidth per slot> = <partition size> * <slots per executor> / <bandwidth per executor>
+ double sizePerExecutor = sizePerPartitionMB * metrics.cpuCores;
+ boolean isSparse = MatrixBlock.evalSparseFormatOnDisk(stats.getM(), stats.getN(), stats.getNNZ());
+ double timePerCore = getStorageWriteTime(sizePerExecutor, isSparse, sourceType, format, metrics); // same as time per executor
+ // number of execution waves (maximum task to execute per core)
+ double numWaves = Math.ceil((double) stats.rddStats.numPartitions /
+ (SparkExecutionContext.getNumExecutors() * metrics.cpuCores));
+ return numWaves * timePerCore;
+ }
+
+ protected static double getStorageWriteTime(double sizeMB, boolean isSparse, String source, Types.FileFormat format, IOMetrics metrics) {
+ if (format == null || !(source.equals(HDFS_SOURCE_IDENTIFIER) || source.equals(S3_SOURCE_IDENTIFIER))) {
+ throw new RuntimeException("Estimation not possible without source identifier and file format");
+ }
+ double time;
+ if (format.isTextFormat()) {
if (source.equals(S3_SOURCE_IDENTIFIER)) {
- if (sparse)
- ret /= DEFAULT_MBS_S3READ_TEXT_SPARSE;
+ if (isSparse)
+ time = sizeMB / metrics.s3WriteTextSparseBandwidth;
+ else // dense
+ time = sizeMB / metrics.s3WriteTextDenseBandwidth;
+ } else { // HDFS
+ if (isSparse)
+ time = sizeMB / metrics.hdfsWriteTextSparseBandwidth;
else //dense
- ret /= DEFAULT_MBS_S3READ_TEXT_DENSE;
- } else { //HDFS
- if (sparse)
- ret /= DEFAULT_MBS_HDFSREAD_TEXT_SPARSE;
- else //dense
- ret /= DEFAULT_MBS_HDFSREAD_TEXT_DENSE;
+ time = sizeMB / metrics.hdfsWriteTextDenseBandwidth;
}
+ } else if (format == Types.FileFormat.BINARY) {
+ if (source.equals(HDFS_SOURCE_IDENTIFIER)) {
+ if (isSparse)
+ time = sizeMB / metrics.hdfsWriteBinarySparseBandwidth;
+ else //dense
+ time = sizeMB / metrics.hdfsWriteBinaryDenseBandwidth;
+ } else { // S3
+				throw new RuntimeException("Writing binary files to S3 is not supported");
+ }
+ } else { // compressed
+ throw new RuntimeException("Format " + format + " is not supported for estimation yet.");
}
- return ret;
+ return time;
}
- protected static double getWriteTime(long dm, long dn, double ds, String source, Types.FileFormat format) {
- boolean sparse = MatrixBlock.evalSparseFormatOnDisk(dm, dn, (long)(ds*dm*dn));
- double bytes = MatrixBlock.estimateSizeOnDisk(dm, dn, (long)(ds*dm*dn));
- double mbytes = bytes / (1024*1024);
- double ret;
+ /**
+	 * Estimates the time to parallelize a local object to Spark.
+	 *
+	 * @param output RDD statistics for the object to be parallelized/transferred.
+	 * @param driverMetrics I/O metrics for the driver node
+ * @param executorMetrics I/O metrics for the executor nodes
+ * @return estimated time in seconds
+ */
+ public static double getSparkParallelizeTime(RDDStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+		// TODO: ensure the object related to stats is already read in memory or add logic to account for its read time
+		// it is assumed that the RDD object is already created/read
+		// general idea: time = <serialization time> + <transfer time>;
+		// NOTE: currently it is assumed that the serialized data has the same size as the original data, which may not be true in the general case
+ double sizeMB = (double) output.distributedSize / (1024 * 1024);
+ // 1. serialization time
+ double serializationTime = sizeMB / driverMetrics.serializationBandwidth;
+ // 2. transfer time
+ double effectiveBandwidth = Math.min(driverMetrics.networkingBandwidth,
+ SparkExecutionContext.getNumExecutors() * executorMetrics.networkingBandwidth);
+ double transferTime = sizeMB / effectiveBandwidth;
+ // sum the time for the steps since they cannot overlap
+ return serializationTime + transferTime;
+ }
- if (source == S3_SOURCE_IDENTIFIER) {
- if (format.isTextFormat()) {
- if (sparse)
- ret = mbytes / DEFAULT_MBS_S3WRITE_TEXT_SPARSE;
- else //dense
- ret = mbytes / DEFAULT_MBS_S3WRITE_TEXT_DENSE;
- ret *= 2.75; //text commonly 2x-3.5x larger than binary
- } else {
- if (sparse)
- ret = mbytes / DEFAULT_MBS_S3WRITE_BINARYBLOCK_SPARSE;
- else //dense
- ret = mbytes / DEFAULT_MBS_S3WRITE_BINARYBLOCK_DENSE;
- }
- } else { //HDFS
- if (format.isTextFormat()) {
- if (sparse)
- ret = mbytes / DEFAULT_MBS_HDFSWRITE_TEXT_SPARSE;
- else //dense
- ret = mbytes / DEFAULT_MBS_HDFSWRITE_TEXT_DENSE;
- ret *= 2.75; //text commonly 2x-3.5x larger than binary
- } else {
- if (sparse)
- ret = mbytes / DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_SPARSE;
- else //dense
- ret = mbytes / DEFAULT_MBS_HDFSWRITE_BINARYBLOCK_DENSE;
- }
+ /**
+ * Estimates the time for collecting Spark Job output;
+ * The output RDD is transferred to the Spark driver at the end of each ResultStage;
+ * time = transfer time (overlaps and dominates the read and deserialization times);
+ *
+ * @param output RDD statistics for the object to be collected/transferred.
+ * @param driverMetrics I/O metrics for the receiver - driver node
+ * @param executorMetrics I/O metrics for the executor nodes
+ * @return estimated time in seconds
+ */
+ public static double getSparkCollectTime(RDDStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ double sizeMB = (double) output.distributedSize / (1024 * 1024);
+ double numWaves = Math.ceil((double) output.numPartitions / SparkExecutionContext.getDefaultParallelism(false));
+ int currentParallelism = Math.min(output.numPartitions, SparkExecutionContext.getDefaultParallelism(false));
+ double bandwidthPerCore = executorMetrics.networkingBandwidth / executorMetrics.cpuCores;
+ double effectiveBandwidth = Math.min(numWaves * driverMetrics.networkingBandwidth,
+ currentParallelism * bandwidthPerCore);
+ // transfer time
+ return sizeMB / effectiveBandwidth;
+ }
+
+ /**
+ * Estimates the time for reading distributed RDD input at the beginning of a Stage;
+ * time = transfer time (overlaps and dominates the read and deserialization times);
+ * For simplification it is assumed that the whole dataset is shuffled;
+ *
+	 * @param input RDD statistics for the object to be shuffled at the beginning of a Stage.
+ * @param metrics I/O metrics for the executor nodes
+ * @return estimated time in seconds
+ */
+ public static double getSparkShuffleReadTime(RDDStats input, IOMetrics metrics) {
+ double sizeMB = (double) input.distributedSize / (1024 * 1024);
+		// edge case: a single executor would not trigger any data transfer over the network
+ if (SparkExecutionContext.getNumExecutors() < 2) {
+ // even without shuffling the data needs to be read from the intermediate shuffle files
+			double diskBandwidthPerCore = metrics.localDiskReadBandwidth / metrics.cpuCores;
+ // disk read time
+ return sizeMB / diskBandwidthPerCore;
}
- return ret;
+ int currentParallelism = Math.min(input.numPartitions, SparkExecutionContext.getDefaultParallelism(false));
+ double networkBandwidthPerCore = metrics.networkingBandwidth / metrics.cpuCores;
+ // transfer time
+ return sizeMB / (currentParallelism * networkBandwidthPerCore);
}
/**
- * Returns the estimated cost for transmitting a packet of size bytes.
- * This function is supposed to be used for parallelize and result data transfer.
- * Driver <-> Executors interaction.
- * @param size
- * @param numExecutors
- * @return
+ * Estimates the time for reading distributed RDD input at the beginning of a Stage
+ * when a wide-transformation is partition preserving: only local disk reads
+ *
+	 * @param input RDD statistics for the object to be shuffled (read) at the beginning of a Stage.
+ * @param metrics I/O metrics for the executor nodes
+ * @return estimated time in seconds
*/
- protected static double getSparkTransmissionCost(long size, int numExecutors) {
- double transferTime = Math.max(((double) size / (DEFAULT_NETWORK_BANDWIDTH * numExecutors)), MIN_TRANSFER_TIME);
- double serializationTime = Math.max((size * SERIALIZATION_FACTOR) / CostEstimator.CP_FLOPS, MIN_SERIALIZATION_TIME);
- return DEFAULT_NETWORK_LATENCY + transferTime + serializationTime;
+ public static double getSparkShuffleReadStaticTime(RDDStats input, IOMetrics metrics) {
+ double sizeMB = (double) input.distributedSize / (1024 * 1024);
+ int currentParallelism = Math.min(input.numPartitions, SparkExecutionContext.getDefaultParallelism(false));
+ double readBandwidthPerCore = metrics.memReadBandwidth / metrics.cpuCores;
+ // read time
+ return sizeMB / (currentParallelism * readBandwidthPerCore);
}
/**
- * Returns the estimated cost for shuffling the records of an RDD of given size.
- * This function assumes that all the records would be reshuffled what often not the case
- * but this approximation is good enough for estimating the shuffle cost with higher skewness.
- * Executors <-> Executors interaction.
- * @param size
- * @param numExecutors
- * @return
+	 * Estimates the time for writing the RDD output to the local file system at the end of a ShuffleMapStage;
+	 * time = disk write time (overlaps and dominates the serialization time);
+	 * The whole data set is written to shuffle files even if only 1 executor is utilized;
+ *
+ * @param output RDD statistics for the output each ShuffleMapStage
+ * @param metrics I/O metrics for the executor nodes
+ * @return estimated time in seconds
*/
- protected static double getShuffleCost(long size, int numExecutors) {
- double transferTime = Math.max(((double) size / (DEFAULT_NETWORK_BANDWIDTH * numExecutors)), MIN_TRANSFER_TIME);
- double serializationTime = Math.max((size * SERIALIZATION_FACTOR) / CostEstimator.SP_FLOPS, MIN_SERIALIZATION_TIME) / numExecutors;
- return DEFAULT_NETWORK_LATENCY * numExecutors + transferTime + serializationTime;
+ public static double getSparkShuffleWriteTime(RDDStats output, IOMetrics metrics) {
+ double sizeMB = (double) output.distributedSize / (1024 * 1024);
+ int currentParallelism = Math.min(output.numPartitions, SparkExecutionContext.getDefaultParallelism(false));
+ double bandwidthPerCore = metrics.localDiskWriteBandwidth / metrics.cpuCores;
+ // disk write time
+ return sizeMB / (currentParallelism * bandwidthPerCore);
}
/**
- * Returns the estimated cost for broadcasting a packet of size bytes.
- * This function takes into account the torrent-like trnasmission of the
- * broadcast data packages.
- * Executors <-> Driver <-> Executors interaction.
- * @param size
- * @param numExecutors
- * @return
+	 * Combines the shuffle write and read time since these are typically
+	 * added together to the overall data transmission time when estimating an instruction.
+ *
+ * @param output RDD statistics for the output each ShuffleMapStage
+ * @param metrics I/O metrics for the executor nodes
+ * @param withDistribution flag if the data is indeed reshuffled (default case),
+ * false in case of co-partitioned wide-transformation
+ * @return estimated time in seconds
*/
- protected static double getBroadcastCost(long size, int numExecutors) {
- double transferTime = Math.max(((double) size / (DEFAULT_NETWORK_BANDWIDTH)), MIN_TRANSFER_TIME);
- double serializationTime = Math.max((size * SERIALIZATION_FACTOR) / CostEstimator.CP_FLOPS, MIN_SERIALIZATION_TIME);
- return DEFAULT_NETWORK_LATENCY * numExecutors + transferTime + serializationTime;
+ public static double getSparkShuffleTime(RDDStats output, IOMetrics metrics, boolean withDistribution) {
+ double totalTime = getSparkShuffleWriteTime(output, metrics);
+ if (withDistribution)
+ totalTime += getSparkShuffleReadTime(output, metrics);
+ else
+ totalTime += getSparkShuffleReadStaticTime(output, metrics);
+ return totalTime;
}
+ /**
+ * Estimates the time for broadcasting an object;
+ * This function takes into account the torrent-like mechanism
+ * for broadcast distribution across all executors;
+ *
+ * @param stats statistics for the object for broadcasting
+ * @param driverMetrics I/O metrics for the driver node
+ * @param executorMetrics I/O metrics for the executor nodes
+ * @return estimated time in seconds
+ */
+ protected static double getSparkBroadcastTime(VarStats stats, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+		// TODO: ensure the object related to stats is already read in memory or add logic to account for its read time
+		// it is assumed that the CP broadcast object is already created/read
+		// general idea: time = <serialization time> + <transfer time>;
+		// NOTE: currently it is assumed that the serialized data has the same size as the original data, which may not be true in the general case
+ double sizeMB = (double) OptimizerUtils.estimatePartitionedSizeExactSparsity(stats.characteristics) / (1024 * 1024);
+ // 1. serialization time
+ double serializationTime = sizeMB / driverMetrics.serializationBandwidth;
+ // 2. transfer time considering the torrent-like mechanism: time to transfer the whole object to a single node
+ double effectiveBandwidth = Math.min(driverMetrics.networkingBandwidth, executorMetrics.networkingBandwidth);
+ double transferTime = sizeMB / effectiveBandwidth;
+ // sum the time for the steps since they cannot overlap
+ return serializationTime + transferTime;
+ }
+
+ /**
+ * Extracts the data source for a given file name: e.g. "hdfs" or "s3"
+ *
+ * @param fileName filename to parse
+ * @return data source type
+ */
public static String getDataSource(String fileName) {
String[] fileParts = fileName.split("://");
if (fileParts.length > 1) {
@@ -208,4 +549,34 @@
}
return HDFS_SOURCE_IDENTIFIER;
}
+
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Helpers //
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ private static double getFileSizeInMB(VarStats fileStats) {
+ Types.FileFormat format = (Types.FileFormat) fileStats.fileInfo[1];
+ double sizeMB;
+ if (format == Types.FileFormat.BINARY) {
+			sizeMB = (double) MatrixBlock.estimateSizeOnDisk(fileStats.getM(), fileStats.getN(), fileStats.getNNZ()) / (1024*1024);
+		} else if (format.isTextFormat()) {
+			sizeMB = (double) OptimizerUtils.estimateSizeTextOutput(fileStats.getM(), fileStats.getN(), fileStats.getNNZ(), format) / (1024*1024);
+ } else { // compressed
+ throw new RuntimeException("Format " + format + " is not supported for estimation yet.");
+ }
+ return sizeMB;
+ }
+
+ private static long getPartitionedFileSize(VarStats fileStats) {
+ Types.FileFormat format = (Types.FileFormat) fileStats.fileInfo[1];
+ long size;
+ if (format == Types.FileFormat.BINARY) {
+			size = MatrixBlock.estimateSizeOnDisk(fileStats.getM(), fileStats.getN(), fileStats.getNNZ());
+		} else if (format.isTextFormat()) {
+			size = OptimizerUtils.estimateSizeTextOutput(fileStats.getM(), fileStats.getN(), fileStats.getNNZ(), format);
+ } else { // compressed
+ throw new RuntimeException("Format " + format + " is not supported for estimation yet.");
+ }
+ return size;
+ }
}
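
To make the wave-based formulas in getHadoopReadTime() concrete, the same arithmetic can be reproduced with plain doubles and the testing defaults declared above. This is an illustrative sketch, not part of the patch; the cluster shape (2 executors with 8 cores each) and the 128MB HDFS block size are assumptions picked only for the example.

	// Worked example (illustrative only): wave-based HDFS read estimate for a ~800MB dense binary file
	public class HadoopReadTimeSketch {
		public static void main(String[] args) {
			double hdfsReadBinaryDenseMBs = 150;         // DEFAULT_MBS_HDFS_READ_BINARY_DENSE
			long hdfsBlockSize = 128L * 1024 * 1024;     // assumed HDFS block size
			int numExecutors = 2, coresPerExecutor = 8;  // assumed cluster shape

			long fileSize = 800L * 1024 * 1024;          // e.g. a 10k x 10k dense double matrix
			double numPartitions = Math.ceil((double) fileSize / hdfsBlockSize);                    // 7 partitions
			double sizePerExecutorMB = (double) (coresPerExecutor * hdfsBlockSize) / (1024 * 1024); // 1024 MB per wave
			double timePerWave = sizePerExecutorMB / hdfsReadBinaryDenseMBs;                        // ~6.8 s
			double numWaves = Math.ceil(numPartitions / (numExecutors * coresPerExecutor));         // 1 wave
			System.out.printf("estimated Spark-side HDFS read time: %.1f s%n", numWaves * timePerWave);
		}
	}

The driver-side counterpart, getFileSystemReadTime(), skips the wave logic and simply divides the file size by the matching bandwidth (800 MB / 150 MB/s, roughly 5.3 s for the same input).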
diff --git a/src/main/java/org/apache/sysds/resource/cost/RDDStats.java b/src/main/java/org/apache/sysds/resource/cost/RDDStats.java
index 6d5d26d..01ca8f1 100644
--- a/src/main/java/org/apache/sysds/resource/cost/RDDStats.java
+++ b/src/main/java/org/apache/sysds/resource/cost/RDDStats.java
@@ -21,56 +21,98 @@
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
+
public class RDDStats {
- @SuppressWarnings("unused")
- private static int blockSize; // TODO: think of more efficient way (does not changes, init once)
- public long totalSize;
- private static long hdfsBlockSize;
- public long numPartitions;
- public long numBlocks;
- public long numValues;
- public long rlen;
- public long clen;
- public double sparsity;
- public VarStats cpVar;
- public int numParallelTasks;
+ long distributedSize;
+ int numPartitions;
+ boolean hashPartitioned;
+ boolean checkpoint;
+ double cost;
+ boolean isCollected;
- public static void setDefaults() {
- blockSize = ConfigurationManager.getBlocksize();
- hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
- }
-
- public RDDStats(VarStats cpVar) {
- totalSize = OptimizerUtils.estimateSizeExactSparsity(cpVar.getM(), cpVar.getN(), cpVar.getS());
- numPartitions = (int) Math.max(Math.min(totalSize / hdfsBlockSize, cpVar._mc.getNumBlocks()), 1);
- numBlocks = cpVar._mc.getNumBlocks();
- this.cpVar = cpVar;
- rlen = cpVar.getM();
- clen = cpVar.getN();
- numValues = rlen*clen;
- sparsity = cpVar.getS();
- numParallelTasks = (int) Math.min(numPartitions, SparkExecutionContext.getDefaultParallelism(false));
- }
-
- public static RDDStats transformNumPartitions(RDDStats oldRDD, long newNumPartitions) {
- if (oldRDD.cpVar == null) {
- throw new DMLRuntimeException("Cannot transform RDDStats without VarStats");
+ /**
+	 * Initializes an RDD statistics object bound
+	 * to an existing {@code VarStats} object.
+	 * Uses the HDFS block size to automatically adjust the
+	 * number of partitions for the current RDD.
+	 *
+	 * @param sourceStats bound variable statistics
+ */
+ public RDDStats(VarStats sourceStats) {
+		// valid bound variable statistics are required (also for non-scalars)
+ if (sourceStats == null) {
+ throw new RuntimeException("RDDStats cannot be initialized without valid input variable statistics");
}
- RDDStats newRDD = new RDDStats(oldRDD.cpVar);
- newRDD.numPartitions = newNumPartitions;
- return newRDD;
+ checkpoint = false;
+ isCollected = false;
+ hashPartitioned = false;
+ // RDD specific characteristics not initialized -> simulates lazy evaluation
+ distributedSize = estimateDistributedSize(sourceStats);
+ numPartitions = getNumPartitions();
+ cost = 0;
}
- public static RDDStats transformNumBlocks(RDDStats oldRDD, long newNumBlocks) {
- if (oldRDD.cpVar == null) {
- throw new DMLRuntimeException("Cannot transform RDDStats without VarStats");
+ /**
+	 * Initializes an RDD statistics object for
+ * intermediate variables (not bound to {@code VarStats}).
+ * Intended to be used for intermediate shuffle estimations.
+ *
+ * @param size distributed size of the object
+ * @param partitions target number of partitions;
+ * -1 for fitting to HDFS block size
+ */
+ public RDDStats(long size, int partitions) {
+ checkpoint = false;
+ isCollected = false;
+ hashPartitioned = false;
+ // RDD specific characteristics not initialized -> simulates lazy evaluation
+ distributedSize = size;
+ if (partitions < 0) {
+ numPartitions = getNumPartitions();
+ } else {
+ numPartitions = partitions;
}
- RDDStats newRDD = new RDDStats(oldRDD.cpVar);
- newRDD.numBlocks = newNumBlocks;
- return newRDD;
+ cost = -1;
+ }
+
+ private int getNumPartitions() {
+ if (distributedSize < 0) {
+ throw new RuntimeException("Estimating number of partitions requires valid distributed RDD size");
+ } else if (distributedSize > 0) {
+ long hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
+ return (int) Math.max((distributedSize + hdfsBlockSize - 1) / hdfsBlockSize, 1);
+ }
+ return -1; // for scalars
+ }
+
+ /**
+	 * Meant to be used for testing
+ * @return estimated time (seconds) for generation of the current RDD
+ */
+ public double getCost() {
+ return cost;
+ }
+
+ /**
+	 * Meant to be used for testing
+ * @return flag if the current RDD is collected
+ */
+ public boolean isCollected() {
+ return isCollected;
+ }
+
+ private static long estimateDistributedSize(VarStats sourceStats) {
+ if (sourceStats.isScalar())
+ return 0; // 0 so it is non-negative
+ if (sourceStats.getCells() < 0)
+ throw new RuntimeException("Estimated size for RDD object is negative");
+ return OptimizerUtils.estimatePartitionedSizeExactSparsity(
+ sourceStats.getM(),
+ sourceStats.getN(),
+ ConfigurationManager.getBlocksize(),
+ sourceStats.getSparsity()
+ );
}
}
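
The partition count chosen in RDDStats.getNumPartitions() is a plain ceiling division of the distributed size by the HDFS block size, with a floor of one partition and -1 for scalars. A minimal standalone sketch of that rule follows; the class name and the 128MB block size are assumptions for illustration only.

	// Illustrative sketch of the ceiling-division rule used for the partition count
	public class PartitionCountSketch {
		static int numPartitions(long distributedSize, long hdfsBlockSize) {
			if (distributedSize < 0)
				throw new IllegalArgumentException("distributed size must be non-negative");
			if (distributedSize == 0)
				return -1; // scalars carry no partitions
			return (int) Math.max((distributedSize + hdfsBlockSize - 1) / hdfsBlockSize, 1);
		}

		public static void main(String[] args) {
			long blockSize = 128L * 1024 * 1024;                              // assumed HDFS block size
			System.out.println(numPartitions(800L * 1024 * 1024, blockSize)); // 7
			System.out.println(numPartitions(1024 * 1024, blockSize));        // 1 (never below one)
		}
	}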
diff --git a/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
new file mode 100644
index 0000000..479cf0f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
@@ -0,0 +1,812 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource.cost;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.*;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+import org.apache.sysds.runtime.instructions.spark.*;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
+import org.apache.sysds.runtime.matrix.operators.CMOperator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+
+import static org.apache.sysds.lops.DataGen.*;
+import static org.apache.sysds.resource.cost.IOCostUtils.*;
+
+public class SparkCostUtils {
+
+ public static double getReblockInstTime(String opcode, VarStats input, VarStats output, IOMetrics executorMetrics) {
+ // Reblock triggers a new stage
+ // old stage: read text file + shuffle the intermediate text rdd
+ double readTime = getHadoopReadTime(input, executorMetrics);
+ long sizeTextFile = OptimizerUtils.estimateSizeTextOutput(input.getM(), input.getN(), input.getNNZ(), (Types.FileFormat) input.fileInfo[1]);
+ RDDStats textRdd = new RDDStats(sizeTextFile, -1);
+ double shuffleTime = getSparkShuffleTime(textRdd, executorMetrics, true);
+ double timeStage1 = readTime + shuffleTime;
+ // new stage: transform partitioned shuffled text object into partitioned binary object
+ long nflop = getInstNFLOP(SPType.Reblock, opcode, output);
+ double timeStage2 = getCPUTime(nflop, textRdd.numPartitions, executorMetrics, output.rddStats, textRdd);
+ return timeStage1 + timeStage2;
+ }
+
+ public static double getRandInstTime(String opcode, int randType, VarStats output, IOMetrics executorMetrics) {
+ if (opcode.equals(SAMPLE_OPCODE)) {
+ // sample uses sortByKey() op. and it should be handled differently
+ throw new RuntimeException("Spark operation Rand with opcode " + SAMPLE_OPCODE + " is not supported yet");
+ }
+
+ long nflop;
+ if (opcode.equals(RAND_OPCODE) || opcode.equals(FRAME_OPCODE)) {
+ if (randType == 0) return 0; // empty matrix
+ else if (randType == 1) nflop = 8; // allocate, array fill
+ else if (randType == 2) nflop = 32; // full rand
+ else throw new RuntimeException("Unknown type of random instruction");
+ } else if (opcode.equals(SEQ_OPCODE)) {
+ nflop = 1;
+ } else {
+ throw new DMLRuntimeException("Rand operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ nflop *= output.getCells();
+ // no shuffling required -> only computation time
+ return getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats);
+ }
+
+ public static double getUnaryInstTime(String opcode, VarStats input, VarStats output, IOMetrics executorMetrics) {
+ // handles operations of type Builtin as Unary
+ // Unary adds a map() to an open stage
+ long nflop = getInstNFLOP(SPType.Unary, opcode, output, input);
+ double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
+ // the resulting rdd is being hash-partitioned depending on the input one
+ output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
+ return mapTime;
+ }
+
+ public static double getAggUnaryInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) {
+ // AggregateUnary results in different Spark execution plan depending on the output dimensions
+ String opcode = inst.getOpcode();
+ AggBinaryOp.SparkAggType aggType = (inst instanceof AggregateUnarySPInstruction)?
+ ((AggregateUnarySPInstruction) inst).getAggType():
+ ((AggregateUnarySketchSPInstruction) inst).getAggType();
+ double shuffleTime;
+ if (inst instanceof CumulativeAggregateSPInstruction) {
+ shuffleTime = getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ } else {
+ if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
+ // loading RDD to the driver (CP) explicitly (not triggered by CP instruction)
+ output.rddStats.isCollected = true;
+ // cost for transferring result values (result from fold()) is negligible -> cost = computation time
+ shuffleTime = 0;
+ } else if (aggType == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
+ // combineByKey() triggers a new stage -> cost = computation time + shuffle time (combineByKey);
+ if (opcode.equals("uaktrace")) {
+ long diagonalBlockSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(
+ input.characteristics.getBlocksize() * input.getM(),
+ input.characteristics.getBlocksize(),
+ input.characteristics.getBlocksize(),
+ input.getNNZ()
+ );
+ RDDStats filteredRDD = new RDDStats(diagonalBlockSize, input.rddStats.numPartitions);
+ shuffleTime = getSparkShuffleTime(filteredRDD, executorMetrics, true);
+ } else {
+ shuffleTime = getSparkShuffleTime(input.rddStats, executorMetrics, true);
+ }
+ output.rddStats.hashPartitioned = true;
+ output.rddStats.numPartitions = input.rddStats.numPartitions;
+ } else { // aggType == AggBinaryOp.SparkAggType.NONE
+ output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
+ output.rddStats.numPartitions = input.rddStats.numPartitions;
+ // only mapping transformation -> cost = computation time
+ shuffleTime = 0;
+ }
+ }
+ long nflop = getInstNFLOP(SPType.AggregateUnary, opcode, output, input);
+ double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
+ return shuffleTime + mapTime;
+ }
+
+ public static double getIndexingInstTime(IndexingSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ String opcode = inst.getOpcode();
+ double dataTransmissionTime;
+ if (opcode.equals(RightIndex.OPCODE)) {
+ // assume direct collecting if output dimensions not larger than block size
+ int blockSize = ConfigurationManager.getBlocksize();
+ if (output.getM() <= blockSize && output.getN() <= blockSize) {
+ // represents single block and multi block cases
+ dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ output.rddStats.isCollected = true;
+ } else {
+ // represents general indexing: worst case: shuffling required
+ dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ }
+ } else if (opcode.equals(LeftIndex.OPCODE)) {
+ // model combineByKey() with shuffling the second input
+ dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ } else { // mapLeftIndex
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ }
+ long nflop = getInstNFLOP(SPType.MatrixIndexing, opcode, output);
+ // scan only the size of the output since filter is applied first
+ RDDStats[] objectsToScan = (input2 == null)? new RDDStats[]{output.rddStats} :
+ new RDDStats[]{output.rddStats, output.rddStats};
+ double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, objectsToScan);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getBinaryInstTime(SPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ SPType opType = inst.getSPInstructionType();
+ String opcode = inst.getOpcode();
+ // binary, builtin binary (log and log_nz)
+		// for the NFLOP calculation it is not relevant whether the function is executed as a map-side operation
+ if (opcode.startsWith("map")) {
+ opcode = opcode.substring(3);
+ }
+ double dataTransmissionTime;
+ if (inst instanceof BinaryMatrixMatrixSPInstruction) {
+ if (inst instanceof BinaryMatrixBVectorSPInstruction) {
+ // the second matrix is always the broadcast one
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+				// flatMapToPair() or mapPartitionsToPair() invoked -> no shuffling
+ output.rddStats.numPartitions = input1.rddStats.numPartitions;
+ output.rddStats.hashPartitioned = input1.rddStats.hashPartitioned;
+ } else { // regular BinaryMatrixMatrixSPInstruction
+ // join() input1 and input2
+ dataTransmissionTime = getSparkShuffleWriteTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleWriteTime(input2.rddStats, executorMetrics);
+ if (input1.rddStats.hashPartitioned) {
+ output.rddStats.numPartitions = input1.rddStats.numPartitions;
+ if (!input2.rddStats.hashPartitioned || !(input1.rddStats.numPartitions == input2.rddStats.numPartitions)) {
+ // shuffle needed for join() -> actual shuffle only for input2
+ dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ } else { // no shuffle needed for join() -> only read from local disk
+ dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadStaticTime(input2.rddStats, executorMetrics);
+ }
+ } else if (input2.rddStats.hashPartitioned) {
+ output.rddStats.numPartitions = input2.rddStats.numPartitions;
+ // input1 not hash partitioned: shuffle needed for join() -> actual shuffle only for input2
+ dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ } else {
+ // repartition all data needed
+ output.rddStats.numPartitions = 2 * output.rddStats.numPartitions;
+ dataTransmissionTime += getSparkShuffleReadTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ }
+ output.rddStats.hashPartitioned = true;
+ }
+ } else if (inst instanceof BinaryMatrixScalarSPInstruction) {
+ // only mapValues() invoked -> no shuffling
+ dataTransmissionTime = 0;
+ output.rddStats.hashPartitioned = (input2.isScalar())? input1.rddStats.hashPartitioned : input2.rddStats.hashPartitioned;
+ } else if (inst instanceof BinaryFrameMatrixSPInstruction || inst instanceof BinaryFrameFrameSPInstruction) {
+ throw new RuntimeException("Handling binary instructions for frames not handled yet.");
+ } else {
+ throw new RuntimeException("Not supported binary instruction: "+inst);
+ }
+ long nflop = getInstNFLOP(opType, opcode, output, input1, input2);
+ double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
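+ // Illustrative note (assumed example, not part of the estimator itself): for an element-wise A + B
+ // where A is hash-partitioned into 8 partitions and B is co-partitioned into the same 8 partitions,
+ // the join() above charges only the shuffle writes of both inputs plus the static (local disk)
+ // shuffle reads; if B is not co-partitioned, B additionally incurs a full shuffle read.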
+
+ public static double getAppendInstTime(AppendSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ double dataTransmissionTime;
+ if (inst instanceof AppendMSPInstruction) {
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ output.rddStats.hashPartitioned = true;
+ } else if (inst instanceof AppendRSPInstruction) {
+ dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, false);
+ } else if (inst instanceof AppendGAlignedSPInstruction) {
+ // only changing matrix indexing
+ dataTransmissionTime = 0;
+ } else { // AppendGSPInstruction
+ // shuffle the whole appended matrix
+ dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ }
+ // the opcode is not relevant for the NFLOP estimation of append instructions
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), "append", output, input1, input2);
+ double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getReorgInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) {
+ // includes logic for MatrixReshapeSPInstruction
+ String opcode = inst.getOpcode();
+ double dataTransmissionTime;
+ switch (opcode) {
+ case "rshape":
+ dataTransmissionTime = getSparkShuffleTime(input.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ break;
+ case "r'":
+ dataTransmissionTime = 0;
+ output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
+ break;
+ case "rev":
+ dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ break;
+ case "rdiag":
+ dataTransmissionTime = 0;
+ output.rddStats.numPartitions = input.rddStats.numPartitions;
+ output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
+ break;
+ default: // rsort
+ String ixretAsString = InstructionUtils.getInstructionParts(inst.getInstructionString())[4];
+ boolean ixret = ixretAsString.equalsIgnoreCase("true");
+ int shuffleFactor;
+ if (ixret) { // index return
+ shuffleFactor = 2; // estimate cost for 2 shuffles
+ } else {
+ shuffleFactor = 4; // estimate cost for 4 shuffles
+ }
+ // assume shuffling the output 'shuffleFactor' times
+ dataTransmissionTime = getSparkShuffleWriteTime(output.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(output.rddStats, executorMetrics);
+ dataTransmissionTime *= shuffleFactor;
+ break;
+ }
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output); // uses output only
+ double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
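+ // Illustrative note: for "rsort" with index return (ixret == true) the estimate above charges
+ // 2 full shuffle cycles (write + read) of the output, and 4 such cycles otherwise, as set by shuffleFactor.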
+
+ public static double getTSMMInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ String opcode = inst.getOpcode();
+ MMTSJ.MMTSJType type;
+
+ double dataTransmissionTime;
+ if (inst instanceof TsmmSPInstruction) {
+ type = ((TsmmSPInstruction) inst).getMMTSJType();
+ // fold() used but result is still a whole matrix block
+ dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ output.rddStats.isCollected = true;
+ } else { // Tsmm2SPInstruction
+ type = ((Tsmm2SPInstruction) inst).getMMTSJType();
+ // assume always the default case of collecting the output
+ long rowsRange = (type == MMTSJ.MMTSJType.LEFT)? input.getM() :
+ input.getM() - input.characteristics.getBlocksize();
+ long colsRange = (type != MMTSJ.MMTSJType.LEFT)? input.getN() :
+ input.getN() - input.characteristics.getBlocksize();
+ VarStats broadcast = new VarStats("tmp1", new MatrixCharacteristics(rowsRange, colsRange));
+ broadcast.rddStats = new RDDStats(broadcast);
+ dataTransmissionTime = getSparkCollectTime(broadcast.rddStats, driverMetrics, executorMetrics);
+ dataTransmissionTime += getSparkBroadcastTime(broadcast, driverMetrics, executorMetrics);
+ dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ }
+ opcode += type.isLeft() ? "_left" : "_right";
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output, input);
+ double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getCentralMomentInstTime(CentralMomentSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics executorMetrics) {
+ CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType();
+ String opcode = inst.getOpcode() + "_" + opType.name().toLowerCase();
+
+ double dataTransmissionTime = 0;
+ if (weights != null) {
+ dataTransmissionTime = getSparkShuffleWriteTime(weights.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(weights.rddStats, executorMetrics);
+
+ }
+ output.rddStats.isCollected = true;
+
+ RDDStats[] RDDInputs = (weights == null)? new RDDStats[]{input.rddStats} : new RDDStats[]{input.rddStats, weights.rddStats};
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output, input);
+ double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, RDDInputs);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getCastInstTime(CastSPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) {
+ double shuffleTime = 0;
+ if (input.getN() > input.characteristics.getBlocksize()) {
+ shuffleTime = getSparkShuffleWriteTime(input.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(input.rddStats, executorMetrics);
+ output.rddStats.hashPartitioned = true;
+ }
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input);
+ double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
+ return shuffleTime + mapTime;
+ }
+
+ public static double getQSortInstTime(QuantileSortSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics executorMetrics) {
+ String opcode = inst.getOpcode();
+ double shuffleTime = 0;
+ if (weights != null) {
+ opcode += "_wts";
+ shuffleTime += getSparkShuffleWriteTime(weights.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(weights.rddStats, executorMetrics);
+ }
+ shuffleTime += getSparkShuffleWriteTime(output.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(output.rddStats, executorMetrics);
+ output.rddStats.hashPartitioned = true;
+
+ long nflop = getInstNFLOP(SPType.QSort, opcode, output, input, weights);
+ RDDStats[] RDDInputs = (weights == null)? new RDDStats[]{input.rddStats} : new RDDStats[]{input.rddStats, weights.rddStats};
+ double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, RDDInputs);
+ return shuffleTime + mapTime;
+ }
+
+ public static double getMatMulInstTime(BinarySPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ double dataTransmissionTime;
+ int numPartitionsForMapping;
+ if (inst instanceof CpmmSPInstruction) {
+ CpmmSPInstruction cpmminst = (CpmmSPInstruction) inst;
+ AggBinaryOp.SparkAggType aggType = cpmminst.getAggType();
+ // estimate for in1.join(in2)
+ long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize;
+ RDDStats joinedRDD = new RDDStats(joinedSize, -1);
+ dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true);
+ if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
+ dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ output.rddStats.isCollected = true;
+ } else {
+ dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ }
+ numPartitionsForMapping = joinedRDD.numPartitions;
+ } else if (inst instanceof RmmSPInstruction) {
+ // estimate for in1.join(in2)
+ long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize;
+ RDDStats joinedRDD = new RDDStats(joinedSize, -1);
+ dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true);
+ // estimate for out.combineByKey() per partition
+ dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, false);
+ output.rddStats.hashPartitioned = true;
+ numPartitionsForMapping = joinedRDD.numPartitions;
+ } else if (inst instanceof MapmmSPInstruction) {
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ MapmmSPInstruction mapmminst = (MapmmSPInstruction) inst;
+ AggBinaryOp.SparkAggType aggType = mapmminst.getAggType();
+ if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
+ dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ output.rddStats.isCollected = true;
+ } else {
+ dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ }
+ numPartitionsForMapping = input1.rddStats.numPartitions;
+ } else if (inst instanceof PmmSPInstruction) {
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ output.rddStats.numPartitions = input1.rddStats.numPartitions;
+ dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ numPartitionsForMapping = input1.rddStats.numPartitions;
+ } else if (inst instanceof ZipmmSPInstruction) {
+ // assume always a shuffle without data re-distribution
+ dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, false);
+ dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ numPartitionsForMapping = input1.rddStats.numPartitions;
+ output.rddStats.isCollected = true;
+ } else if (inst instanceof PMapmmSPInstruction) {
+ throw new RuntimeException("PMapmmSPInstruction instruction is still experimental and not supported yet");
+ } else {
+ throw new RuntimeException(inst.getClass().getName() + " instruction is not handled by the current method");
+ }
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input1, input2);
+ double mapTime;
+ if (inst instanceof MapmmSPInstruction || inst instanceof PmmSPInstruction) {
+ // scan only first input
+ mapTime = getCPUTime(nflop, numPartitionsForMapping, executorMetrics, output.rddStats, input1.rddStats);
+ } else {
+ mapTime = getCPUTime(nflop, numPartitionsForMapping, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
+ }
+ return dataTransmissionTime + mapTime;
+ }
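+ // Illustrative summary of the cases above: cpmm and rmm join both RDD inputs (shuffling the
+ // combined size), mapmm and pmm broadcast the second input instead, and zipmm assumes a shuffle
+ // without data re-distribution followed by a collect; single-block aggregation outputs (and the
+ // zipmm result) are collected at the driver, the remaining outputs stay hash-partitioned on the executors.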
+
+ public static double getMatMulChainInstTime(MapmmChainSPInstruction inst, VarStats input1, VarStats input2, VarStats input3, VarStats output,
+ IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ double dataTransmissionTime = 0;
+ if (input3 != null) {
+ dataTransmissionTime += getSparkBroadcastTime(input3, driverMetrics, executorMetrics);
+ }
+ dataTransmissionTime += getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+ output.rddStats.isCollected = true;
+
+ long nflop = getInstNFLOP(SPType.MAPMMCHAIN, inst.getOpcode(), output, input1, input2);
+ double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getCtableInstTime(CtableSPInstruction tableInst, VarStats input1, VarStats input2, VarStats input3, VarStats output, IOMetrics executorMetrics) {
+ String opcode = tableInst.getOpcode();
+ double shuffleTime;
+ if (opcode.equals("ctableexpand") || !input2.isScalar() && input3.isScalar()) { // CTABLE_EXPAND_SCALAR_WEIGHT/CTABLE_TRANSFORM_SCALAR_WEIGHT
+ // in1.join(in2)
+ shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ } else if (input2.isScalar() && input3.isScalar()) { // CTABLE_TRANSFORM_HISTOGRAM
+ // no joins
+ shuffleTime = 0;
+ } else if (input2.isScalar() && !input3.isScalar()) { // CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM
+ // in1.join(in3)
+ shuffleTime = getSparkShuffleTime(input3.rddStats, executorMetrics, true);
+ } else { // CTABLE_TRANSFORM
+ // in1.join(in2).join(in3)
+ shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ shuffleTime += getSparkShuffleTime(input3.rddStats, executorMetrics, true);
+ }
+ // combineByKey()
+ shuffleTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+
+ long nflop = getInstNFLOP(SPType.Ctable, opcode, output, input1, input2, input3);
+ double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics,
+ output.rddStats, input1.rddStats, input2.rddStats, input3.rddStats);
+
+ return shuffleTime + mapTime;
+ }
+
+ public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinSPInstruction paramInst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ String opcode = paramInst.getOpcode();
+ double dataTransmissionTime;
+ switch (opcode) {
+ case "rmempty":
+ if (input2.rddStats == null) // broadcast
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ else // join
+ dataTransmissionTime = getSparkShuffleTime(input1.rddStats, executorMetrics, true);
+ dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ break;
+ case "contains":
+ if (input2.isScalar()) {
+ dataTransmissionTime = 0;
+ } else {
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ // ignore reduceByKey() cost
+ }
+ output.rddStats.isCollected = true;
+ break;
+ case "replace":
+ case "lowertri":
+ case "uppertri":
+ dataTransmissionTime = 0;
+ break;
+ default:
+ throw new RuntimeException("Spark operation ParameterizedBuiltin with opcode " + opcode + " is not supported yet");
+ }
+
+ long nflop = getInstNFLOP(paramInst.getSPInstructionType(), opcode, output, input1);
+ double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats);
+
+ return dataTransmissionTime + mapTime;
+ }
+
+ /**
+ * Computes an estimate of the time needed by the CPU to execute (including memory access)
+ * an instruction, given its number of floating point operations.
+ *
+ * @param nflop number of FLOP needed to execute the target CPU operation
+ * @param numPartitions number of partitions used to execute the target operation;
+ * not bound to any of the input/output statistics objects to allow more
+ * flexibility depending on the corresponding instruction
+ * @param executorMetrics metrics for the executors utilized by the Spark cluster
+ * @param output statistics for the output variable
+ * @param inputs array of statistics for the input variables
+ * @return time estimate
+ */
+ public static double getCPUTime(long nflop, int numPartitions, IOMetrics executorMetrics, RDDStats output, RDDStats...inputs) {
+ double memScanTime = 0;
+ for (RDDStats input: inputs) {
+ if (input == null) continue;
+ // compensates for spill-overs to account for non-compute bound operations
+ memScanTime += getMemReadTime(input, executorMetrics);
+ }
+ double numWaves = Math.ceil((double) numPartitions / SparkExecutionContext.getDefaultParallelism(false));
+ double scaledNFLOP = (numWaves * nflop) / numPartitions;
+ double cpuComputationTime = scaledNFLOP / executorMetrics.cpuFLOPS;
+ double memWriteTime = output != null? getMemWriteTime(output, executorMetrics) : 0;
+ return Math.max(memScanTime, cpuComputationTime) + memWriteTime;
+ }
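+ // Worked example (hypothetical numbers): with numPartitions = 10, a default parallelism of 4
+ // and nflop = 1e9, numWaves = ceil(10 / 4) = 3 and scaledNFLOP = (3 * 1e9) / 10 = 3e8 per
+ // partition; the resulting compute time 3e8 / cpuFLOPS is compared against the memory scan
+ // time of the inputs and the larger of the two is charged, plus the output memory write time.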
+
+ public static void assignOutputRDDStats(SPInstruction inst, VarStats output, VarStats...inputs) {
+ if (!output.isScalar()) {
+ SPType instType = inst.getSPInstructionType();
+ String opcode = inst.getOpcode();
+ if (output.getCells() < 0) {
+ inferStats(instType, opcode, output, inputs);
+ }
+ }
+ output.rddStats = new RDDStats(output);
+ }
+
+ private static void inferStats(SPType instType, String opcode, VarStats output, VarStats...inputs) {
+ switch (instType) {
+ case Unary:
+ case Builtin:
+ CPCostUtils.inferStats(CPType.Unary, opcode, output, inputs);
+ break;
+ case AggregateUnary:
+ case AggregateUnarySketch:
+ CPCostUtils.inferStats(CPType.AggregateUnary, opcode, output, inputs);
+ break;
+ case MatrixIndexing:
+ CPCostUtils.inferStats(CPType.MatrixIndexing, opcode, output, inputs);
+ break;
+ case Reorg:
+ CPCostUtils.inferStats(CPType.Reorg, opcode, output, inputs);
+ break;
+ case Binary:
+ CPCostUtils.inferStats(CPType.Binary, opcode, output, inputs);
+ break;
+ case CPMM:
+ case RMM:
+ case MAPMM:
+ case PMM:
+ case ZIPMM:
+ CPCostUtils.inferStats(CPType.AggregateBinary, opcode, output, inputs);
+ break;
+ case ParameterizedBuiltin:
+ CPCostUtils.inferStats(CPType.ParameterizedBuiltin, opcode, output, inputs);
+ break;
+ case Rand:
+ CPCostUtils.inferStats(CPType.Rand, opcode, output, inputs);
+ break;
+ case Ctable:
+ CPCostUtils.inferStats(CPType.Ctable, opcode, output, inputs);
+ break;
+ default:
+ throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions");
+ }
+ if (output.getCells() < 0) {
+ throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions");
+ }
+ }
+
+ private static long getInstNFLOP(
+ SPType instructionType,
+ String opcode,
+ VarStats output,
+ VarStats...inputs
+ ) {
+ opcode = opcode.toLowerCase();
+ double costs;
+ switch (instructionType) {
+ case Reblock:
+ if (opcode.startsWith("libsvm")) {
+ return output.getCellsWithSparsity();
+ } else { // starts with "rblk" or "csvrblk"
+ return output.getCells();
+ }
+ case Unary:
+ case Builtin:
+ return CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, inputs);
+ case AggregateUnary:
+ case AggregateUnarySketch:
+ switch (opcode) {
+ case "uacdr":
+ case "uacdc":
+ throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS");
+ default:
+ return CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, inputs);
+ }
+ case CumsumAggregate:
+ switch (opcode) {
+ case "ucumack+":
+ case "ucumac*":
+ case "ucumacmin":
+ case "ucumacmax":
+ costs = 1; break;
+ case "ucumac+*":
+ costs = 2; break;
+ default:
+ throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS");
+ }
+ return (long) (costs * inputs[0].getCells() + costs * output.getN());
+ case TSMM:
+ case TSMM2:
+ return CPCostUtils.getInstNFLOP(CPType.MMTSJ, opcode, output, inputs);
+ case Reorg:
+ case MatrixReshape:
+ return CPCostUtils.getInstNFLOP(CPType.Reorg, opcode, output, inputs);
+ case MatrixIndexing:
+ // the actual opcode value is not used at the moment
+ return CPCostUtils.getInstNFLOP(CPType.MatrixIndexing, opcode, output, inputs);
+ case Cast:
+ return output.getCellsWithSparsity();
+ case QSort:
+ return CPCostUtils.getInstNFLOP(CPType.QSort, opcode, output, inputs);
+ case CentralMoment:
+ return CPCostUtils.getInstNFLOP(CPType.CentralMoment, opcode, output, inputs);
+ case UaggOuterChain:
+ case Dnn:
+ throw new RuntimeException("Spark operation type'" + instructionType + "' is not supported yet");
+ // types corresponding to BinaryCPInstruction
+ case Binary:
+ switch (opcode) {
+ case "+*":
+ case "-*":
+ // original "map+*" and "map-*"
+ // "+*" and "-*" defined as ternary
+ throw new RuntimeException("Spark operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ return CPCostUtils.getInstNFLOP(CPType.Binary, opcode, output, inputs);
+ }
+ case CPMM:
+ case RMM:
+ case MAPMM:
+ case PMM:
+ case ZIPMM:
+ case PMAPMM:
+ // do not reduce by factor of 2: not explicit matrix multiplication
+ return 2 * CPCostUtils.getInstNFLOP(CPType.AggregateBinary, opcode, output, inputs);
+ case MAPMMCHAIN:
+ return 2 * inputs[0].getCells() * inputs[0].getN() // ba(+*)
+ + 2 * inputs[0].getM() * inputs[1].getN() // cellwise b(*) + r(t)
+ + 2 * inputs[0].getCellsWithSparsity() * inputs[1].getN() // ba(+*)
+ + inputs[1].getM() * output.getM(); // r(t)
+ case BinUaggChain:
+ break;
+ case MAppend:
+ case RAppend:
+ case GAppend:
+ case GAlignedAppend:
+ // the actual opcode value is not used at the moment
+ return CPCostUtils.getInstNFLOP(CPType.Append, opcode, output, inputs);
+ case BuiltinNary:
+ return CPCostUtils.getInstNFLOP(CPType.BuiltinNary, opcode, output, inputs);
+ case Ctable:
+ return CPCostUtils.getInstNFLOP(CPType.Ctable, opcode, output, inputs);
+ case ParameterizedBuiltin:
+ return CPCostUtils.getInstNFLOP(CPType.ParameterizedBuiltin, opcode, output, inputs);
+ default:
+ // all existing cases should have been handled above
+ throw new DMLRuntimeException("Spark operation type'" + instructionType + "' is not supported by SystemDS");
+ }
+ throw new RuntimeException();
+ }
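+ // Worked example (hypothetical sizes): for the CumsumAggregate opcode "ucumack+" on a
+ // 1000x1000 input, costs = 1, so the estimate is 1 * 1,000,000 (input cells)
+ // + 1 * 1,000 (output columns) = 1,001,000 FLOP.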
+
+
+// Opcodes not yet covered by getInstNFLOP, kept here as a reference for future extensions:
+// - ternary aggregate operators: tak+*, tack+*
+// - neural network operators: conv2d, conv2d_bias_add, maxpooling, relu_maxpooling
+// - indexing opcodes: RightIndex.OPCODE, LeftIndex.OPCODE, mapLeftIndex, _map
+// - Spark-specific instructions: Checkpoint.DEFAULT_CP_OPCODE, Checkpoint.ASYNC_CP_OPCODE,
+//   Compression.OPCODE, DeCompression.OPCODE
+// - parameterized builtin functions: autoDiff, contains, groupedagg, mapgroupedagg, rmempty,
+//   replace, rexpand, lowertri, uppertri, tokenize, transformapply, transformdecode, transformencode
+// - append opcodes: mappend, rappend, gappend, galignedappend
+// - ternary instruction opcodes: ctable, ctableexpand, +*, -*, ifelse
+// - quaternary instruction opcodes: WeightedSquaredLoss(R), WeightedSigmoid(R), WeightedDivMM(R),
+//   WeightedCrossEntropy(R), WeightedUnaryMM(R)
+// - cumulative offset opcodes: bcumoffk+, bcumoff*, bcumoff+*, bcumoffmin, bcumoffmax
+// - central moment, covariance, quantiles (sort/pick): cm, cov, qsort, qpick
+// - others: binuaggchain, write, spoof
+}
diff --git a/src/main/java/org/apache/sysds/resource/cost/VarStats.java b/src/main/java/org/apache/sysds/resource/cost/VarStats.java
index 93bbd0e..0ec34aa 100644
--- a/src/main/java/org/apache/sysds/resource/cost/VarStats.java
+++ b/src/main/java/org/apache/sysds/resource/cost/VarStats.java
@@ -24,66 +24,104 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-public class VarStats
+public class VarStats
{
- MatrixCharacteristics _mc;
+ // helps with debugging and carries the value for scalar literals
+ String varName;
/**
- * Size in memory estimate
+ * <li>null if scalar</li>
+ * <li>initialized if Matrix or Frame</li>
+ */
+ MatrixCharacteristics characteristics;
+ /**
+ * estimated size in memory
* <li>-1 if not in memory yet</li>
* <li>0 if scalar</li>
+ * <li>>=1 estimated loaded size in Bytes</li>
*/
- long _memory;
+ long allocatedMemory;
+ // refCount/selfRefCount track cases of variable copying (operations 'cpvar' or 'fcall')
+ // increase/decrease only one of them at a time (selfRefCount is not a refCount)
+ int refCount;
+ int selfRefCount;
/**
- * true if object modified since last saved, or
- * if HDFS file still doesn't exist
+ * Always contains 2 elements:
+ * first element: {@code String} with the source type (hdfs, s3 or local)
+ * second element: {@code Types.FileFormat} value
*/
- boolean _dirty = false;
+ Object[] fileInfo = null;
+ RDDStats rddStats = null;
- RDDStats _rdd = null;
-
- Object[] _fileInfo = null;
-
- public VarStats(DataCharacteristics dc) {
- this(dc, -1);
+ public VarStats(String name, DataCharacteristics dc) {
+ varName = name;
+ if (dc == null) {
+ characteristics = null; // for scalar
+ allocatedMemory = 0;
+ } else if (dc instanceof MatrixCharacteristics) {
+ characteristics = (MatrixCharacteristics) dc;
+ allocatedMemory = -1;
+ } else {
+ throw new RuntimeException("Unexpected error: expecting MatrixCharacteristics or null");
+ }
+ refCount = 1;
+ selfRefCount = 1;
}
- public VarStats(DataCharacteristics dc, long sizeEstimate) {
- if (dc == null) {
- _mc = null;
- }
- else if (dc instanceof MatrixCharacteristics) {
- _mc = (MatrixCharacteristics) dc;
- } else {
- throw new RuntimeException("VarStats: expecting MatrixCharacteristics or null");
- }
- _memory = sizeEstimate;
+ public boolean isScalar() {
+ return characteristics == null;
}
public long getM() {
- return _mc.getRows();
+ return isScalar()? 1 : characteristics.getRows();
}
public long getN() {
- return _mc.getCols();
+ return isScalar()? 1 : characteristics.getCols();
}
- public double getS() {
- return _mc == null? 1.0 : OptimizerUtils.getSparsity(_mc);
+ public long getNNZ() {
+ return isScalar()? 1 : characteristics.getNonZerosBound();
+ }
+
+ public double getSparsity() {
+ return isScalar()? 1.0 : OptimizerUtils.getSparsity(characteristics);
}
public long getCells() {
- return _mc.getRows() * _mc.getCols();
+ return isScalar()? 1 : !characteristics.dimsKnown()? -1 :
+ characteristics.getLength();
}
- public double getCellsWithSparsity() {
- if (isSparse())
- return getCells() * getS();
- return (double) getCells();
+ public long getCellsWithSparsity() {
+ if (isScalar()) return 1;
+ return (long) (getCells() * getSparsity());
}
public boolean isSparse() {
- return MatrixBlock.evalSparseFormatInMemory(_mc);
+ return (!isScalar() && MatrixBlock.evalSparseFormatInMemory(characteristics));
}
- // clone() needed?
+ /**
+ * Meant to be used for testing
+ * @param memory size to allocate
+ */
+ public void setAllocatedMemory(long memory) {
+ allocatedMemory = memory;
+ }
+
+ /**
+ * Meant to be used for testing
+ * @return corresponding RDD statistics
+ */
+ public RDDStats getRddStats() {
+ return rddStats;
+ }
+
+ /**
+ * Meant to be used for testing
+ * @param rddStats corresponding RDD statistics
+ */
+ public void setRddStats(RDDStats rddStats) {
+ this.rddStats = rddStats;
+ }
}
diff --git a/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
index 3893f01..9a51c78 100644
--- a/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
+++ b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
@@ -317,7 +317,8 @@
// estimate execution time of the current program
// TODO: pass further relevant cluster configurations to cost estimator after extending it
// like for example: FLOPS, I/O and networking speed
- timeCost = CostEstimator.estimateExecutionTime(program) + CloudUtils.DEFAULT_CLUSTER_LAUNCH_TIME;
+ timeCost = CostEstimator.estimateExecutionTime(program, point.driverInstance, point.executorInstance)
+ + CloudUtils.DEFAULT_CLUSTER_LAUNCH_TIME;
} catch (CostEstimationException e) {
throw new RuntimeException(e.getMessage());
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 0710a4f..765da45 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -56,10 +56,12 @@
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.CompilerConfig;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.resource.CloudUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.io.ReaderSparkCompressed;
@@ -160,10 +162,8 @@
}
public static void initLocalSparkContext(SparkConf sparkConf) {
- if (_sconf == null) {
- _sconf = new SparkClusterConfig();
- }
- _sconf.analyzeSparkConfiguation(sparkConf);
+ // allows re-initialization
+ _sconf = new SparkClusterConfig(sparkConf);
}
public synchronized static JavaSparkContext getSparkContextStatic() {
@@ -1884,6 +1884,23 @@
LOG.debug( this.toString() );
}
+ // Meant to be used only for resource optimization
+ public SparkClusterConfig(SparkConf sconf)
+ {
+ _confOnly = true;
+
+ //parse version and config
+ String sparkVersion = CloudUtils.SPARK_VERSION;
+ _legacyVersion = (UtilFunctions.compareVersion(sparkVersion, "1.6.0") < 0
+ || sconf.getBoolean("spark.memory.useLegacyMode", false) );
+
+ //obtain basic spark configurations
+ if( _legacyVersion )
+ analyzeSparkConfiguationLegacy(sconf);
+ else
+ analyzeSparkConfiguation(sconf);
+ }
+
public long getBroadcastMemoryBudget() {
return (long) (_memExecutor * _memBroadcastFrac);
}
@@ -1978,6 +1995,11 @@
_defaultPar = 2;
_confOnly &= true;
}
+ else if (ConfigurationManager.getCompilerConfigFlag(CompilerConfig.ConfigType.RESOURCE_OPTIMIZATION)) {
+ _numExecutors = numExecutors;
+ _defaultPar = numExecutors * numCoresPerExec;
+ _confOnly = true;
+ }
else {
//get default parallelism (total number of executors and cores)
//note: spark context provides this information while conf does not
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 6611e1a..f199aef 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -56,10 +56,6 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.serialization.ObjectEncoder;
-import io.netty.handler.ssl.SslContext;
-import io.netty.handler.ssl.SslContextBuilder;
-import io.netty.handler.ssl.SslHandler;
-import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.util.concurrent.Promise;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index cb15422..e6f553d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -298,6 +298,10 @@
return _boundOutputNames;
}
+ public List<String> getFunArgNames() {
+ return _funArgNames;
+ }
+
public String updateInstStringFunctionName(String pattern, String replace)
{
//split current instruction
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
index fb19450..18f5613 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
@@ -148,4 +148,8 @@
if( r_op.fn instanceof DiagIndex && soresBlock.getNumColumns()>1 ) //diagV2M
ec.getMatrixObject(output.getName()).setDiag(true);
}
+
+ public CPOperand getIxRet() {
+ return _ixret;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index dcbdcc4..8826c41 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -758,7 +758,7 @@
}
private void setCacheableDataFields(CacheableData<?> obj){
- //clone meta data because it is updated on copy-on-write, otherwise there
+ //clone metadata because it is updated on copy-on-write, otherwise there
//is potential for hidden side effects between variables.
obj.setMetaData((MetaData)metadata.clone());
obj.enableCleanup(!getInput1().getName()
@@ -1336,7 +1336,7 @@
// Find a start position of file name string.
int iPos = StringUtils.ordinalIndexOf(instString, Lop.OPERAND_DELIMITOR, CREATEVAR_FILE_NAME_VAR_POS);
- // Find a end position of file name string.
+ // Find an end position of file name string.
int iPos2 = StringUtils.indexOf(instString, Lop.OPERAND_DELIMITOR, iPos+1);
StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
index fcd0760..a58b3ca 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -173,6 +173,10 @@
throw new NotImplementedException("Aggregate sketch instruction for tensors has not been implemented yet.");
}
+ public AggBinaryOp.SparkAggType getAggType() {
+ return aggtype;
+ }
+
private static class AggregateUnarySketchCreateFunction
implements Function<Tuple2<MatrixIndexes, MatrixBlock>, CorrMatrixBlock> {
private static final long serialVersionUID = 7295176181965491548L;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
index 80d25f2..c311850 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
@@ -66,8 +66,8 @@
public class BuiltinNarySPInstruction extends SPInstruction implements LineageTraceable
{
- private CPOperand[] inputs;
- private CPOperand output;
+ public CPOperand[] inputs;
+ public CPOperand output;
protected BuiltinNarySPInstruction(CPOperand[] in, CPOperand out, String opcode, String istr) {
super(SPType.BuiltinNary, opcode, istr);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
index e2f4e5d..68c5e55 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
@@ -44,25 +44,25 @@
public class MapmmChainSPInstruction extends SPInstruction implements LineageTraceable {
private ChainType _chainType = null;
- private CPOperand _input1 = null;
- private CPOperand _input2 = null;
- private CPOperand _input3 = null;
- private CPOperand _output = null;
+ public CPOperand input1 = null;
+ public CPOperand input2 = null;
+ public CPOperand input3 = null;
+ public CPOperand output = null;
private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, ChainType type, String opcode, String istr) {
super(SPType.MAPMMCHAIN, op, opcode, istr);
- _input1 = in1;
- _input2 = in2;
- _output = out;
+ input1 = in1;
+ input2 = in2;
+ output = out;
_chainType = type;
}
private MapmmChainSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, ChainType type, String opcode, String istr) {
super(SPType.MAPMMCHAIN, op, opcode, istr);
- _input1 = in1;
- _input2 = in2;
- _input3 = in3;
- _output = out;
+ input1 = in1;
+ input2 = in2;
+ input3 = in3;
+ output = out;
_chainType = type;
}
@@ -102,8 +102,8 @@
SparkExecutionContext sec = (SparkExecutionContext)ec;
//get rdd and broadcast inputs
- JavaPairRDD<MatrixIndexes,MatrixBlock> inX = sec.getBinaryMatrixBlockRDDHandleForVariable( _input1.getName() );
- PartitionedBroadcast<MatrixBlock> inV = sec.getBroadcastForVariable( _input2.getName() );
+ JavaPairRDD<MatrixIndexes,MatrixBlock> inX = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
+ PartitionedBroadcast<MatrixBlock> inV = sec.getBroadcastForVariable( input2.getName() );
//execute mapmmchain (guaranteed to have single output block)
MatrixBlock out = null;
@@ -112,21 +112,21 @@
out = RDDAggregateUtils.sumStable(tmp);
}
else { // ChainType.XtwXv / ChainType.XtXvy
- PartitionedBroadcast<MatrixBlock> inW = sec.getBroadcastForVariable( _input3.getName() );
+ PartitionedBroadcast<MatrixBlock> inW = sec.getBroadcastForVariable( input3.getName() );
JavaRDD<MatrixBlock> tmp = inX.map(new RDDMapMMChainFunction2(inV, inW, _chainType));
out = RDDAggregateUtils.sumStable(tmp);
}
//put output block into symbol table (no lineage because single block)
//this also includes implicit maintenance of matrix characteristics
- sec.setMatrixOutput(_output.getName(), out);
+ sec.setMatrixOutput(output.getName(), out);
}
@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
CPOperand chainT = new CPOperand(_chainType.name(), Types.ValueType.INT64, Types.DataType.SCALAR, true);
- return Pair.of(_output.getName(), new LineageItem(getOpcode(),
- LineageItemUtils.getLineage(ec, _input1, _input2, _input3, chainT)));
+ return Pair.of(output.getName(), new LineageItem(getOpcode(),
+ LineageItemUtils.getLineage(ec, input1, input2, input3, chainT)));
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index f4a134e..dac25c3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -197,6 +197,8 @@
return _aggtype;
}
+ public CacheType getCacheType() { return _type; }
+
private static boolean preservesPartitioning(DataCharacteristics mcIn, CacheType type )
{
if( type == CacheType.LEFT )
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
index e320105..ac2d8f4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
@@ -78,6 +78,10 @@
_type = type;
}
+ public LixCacheType getLixType() {
+ return _type;
+ }
+
@Override
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java
index cbaf347..723db9f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/PmmSPInstruction.java
@@ -98,6 +98,8 @@
updateBinaryMMOutputDataCharacteristics(sec, false);
}
+ public CacheType getCacheType() { return _type; }
+
private static class RDDPMMFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -1696560050436469140L;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index abc30b0..de01a71 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -108,9 +108,8 @@
CPOperand desc = new CPOperand(parts[3]);
CPOperand ixret = new CPOperand(parts[4]);
boolean bSortIndInMem = false;
-
- if(parts.length > 5)
- bSortIndInMem = Boolean.parseBoolean(parts[6]);
+
+ bSortIndInMem = Boolean.parseBoolean(parts[6]);
return new ReorgSPInstruction(new ReorgOperator(new SortIndex(1,false,false)),
in, col, desc, ixret, out, opcode, bSortIndInMem, str);
@@ -249,7 +248,11 @@
}
}
- private static class RDDDiagV2MFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
+ public CPOperand getIxRet() {
+ return _ixret;
+ }
+
+ private static class RDDDiagV2MFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = 31065772250744103L;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/Tsmm2SPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/Tsmm2SPInstruction.java
index b0460d5..9830a30 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/Tsmm2SPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/Tsmm2SPInstruction.java
@@ -116,7 +116,11 @@
}
}
- private static class RDDTSMM2Function implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
+ public MMTSJType getMMTSJType() {
+ return _type;
+ }
+
+ private static class RDDTSMM2Function implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = 2935770425858019666L;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index dd6ddb5..45183cf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -93,6 +93,11 @@
}
}
+ public MMTSJType getMMTSJType()
+ {
+ return _type;
+ }
+
private static class RDDTSMMFunction implements Function<Tuple2<MatrixIndexes,MatrixBlock>, MatrixBlock>
{
private static final long serialVersionUID = 2935770425858019666L;
diff --git a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
index 85d8012..079a08c 100644
--- a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
+++ b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
@@ -19,8 +19,6 @@
package org.apache.sysds.utils.stats;
-import org.apache.commons.lang3.function.TriFunction;
-
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Comparator;
diff --git a/src/test/java/org/apache/sysds/test/component/resource/CPCostUtilsTest.java b/src/test/java/org/apache/sysds/test/component/resource/CPCostUtilsTest.java
new file mode 100644
index 0000000..303f2be
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/resource/CPCostUtilsTest.java
@@ -0,0 +1,581 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import org.apache.sysds.resource.cost.CPCostUtils;
+import org.apache.sysds.resource.cost.VarStats;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class CPCostUtilsTest {
+
+ @Test
+ public void testUnaryNotInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("!", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUnaryIsnaInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("isna", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUnaryIsnanInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("isnan", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUnaryIsinfInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("isinf", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUnaryCeilInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("ceil", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUnaryFloorInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("floor", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testAbsInstNFLOPDefaultSparsity() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("abs", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testAbsInstNFLOPSparse() {
+ long expectedValue = (long) (0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("abs", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testRoundInstNFLOPDefaultSparsity() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("round", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testRoundInstNFLOPSparse() {
+ long expectedValue = (long) (0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("round", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testSignInstNFLOPDefaultSparsity() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("sign", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testSignInstNFLOPSparse() {
+ long expectedValue = (long) (0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("sign", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testSpropInstNFLOPDefaultSparsity() {
+ long expectedValue = 2 * 1000 * 1000;
+ testUnaryInstNFLOP("sprop", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testSpropInstNFLOPSparse() {
+ long expectedValue = (long) (2 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("sprop", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testSqrtInstNFLOPDefaultSparsity() {
+ long expectedValue = 2 * 1000 * 1000;
+ testUnaryInstNFLOP("sqrt", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testSqrtInstNFLOPSparse() {
+ long expectedValue = (long) (2 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("sqrt", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testExpInstNFLOPDefaultSparsity() {
+ long expectedValue = 18 * 1000 * 1000;
+ testUnaryInstNFLOP("exp", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testExpInstNFLOPSparse() {
+ long expectedValue = (long) (18 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("exp", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testSigmoidInstNFLOPDefaultSparsity() {
+ long expectedValue = 21 * 1000 * 1000;
+ testUnaryInstNFLOP("sigmoid", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testSigmoidInstNFLOPSparse() {
+ long expectedValue = (long) (21 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("sigmoid", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testPlogpInstNFLOPDefaultSparsity() {
+ long expectedValue = 32 * 1000 * 1000;
+ testUnaryInstNFLOP("plogp", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testPlogpInstNFLOPSparse() {
+ long expectedValue = (long) (32 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("plogp", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testPrintInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("print", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testAssertInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("assert", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testSinInstNFLOPDefaultSparsity() {
+ long expectedValue = 18 * 1000 * 1000;
+ testUnaryInstNFLOP("sin", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testSinInstNFLOPSparse() {
+ long expectedValue = (long) (18 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("sin", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testCosInstNFLOPDefaultSparsity() {
+ long expectedValue = 22 * 1000 * 1000;
+ testUnaryInstNFLOP("cos", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testCosInstNFLOPSparse() {
+ long expectedValue = (long) (22 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("cos", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testTanInstNFLOPDefaultSparsity() {
+ long expectedValue = 42 * 1000 * 1000;
+ testUnaryInstNFLOP("tan", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testTanInstNFLOPSparse() {
+ long expectedValue = (long) (42 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("tan", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testAsinInstNFLOP() {
+ long expectedValue = 93 * 1000 * 1000;
+ testUnaryInstNFLOP("asin", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testSinhInstNFLOP() {
+ long expectedValue = 93 * 1000 * 1000;
+ testUnaryInstNFLOP("sinh", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testAcosInstNFLOP() {
+ long expectedValue = 103 * 1000 * 1000;
+ testUnaryInstNFLOP("acos", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testCoshInstNFLOP() {
+ long expectedValue = 103 * 1000 * 1000;
+ testUnaryInstNFLOP("cosh", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testAtanInstNFLOP() {
+ long expectedValue = 40 * 1000 * 1000;
+ testUnaryInstNFLOP("atan", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testTanhInstNFLOP() {
+ long expectedValue = 40 * 1000 * 1000;
+ testUnaryInstNFLOP("tanh", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUcumkPlusInstNFLOPDefaultSparsity() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("ucumk+", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUcumkPlusInstNFLOPSparse() {
+ long expectedValue = (long) (0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("ucumk+", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUcumMinInstNFLOPDefaultSparsity() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("ucummin", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUcumMinInstNFLOPSparse() {
+ long expectedValue = (long) (0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("ucummin", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUcumMaxInstNFLOPDefaultSparsity() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("ucummax", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUcumMaxInstNFLOPSparse() {
+ long expectedValue = (long) (0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("ucummax", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUcumMultInstNFLOPDefaultSparsity() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("ucum*", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUcumMultInstNFLOPSparse() {
+ long expectedValue = (long) (0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("ucum*", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUcumkPlusMultInstNFLOPDefaultSparsity() {
+ long expectedValue = 2 * 1000 * 1000;
+ testUnaryInstNFLOP("ucumk+*", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testUcumkPlusMultInstNFLOPSparse() {
+ long expectedValue = (long) (2 * 0.5 * 1000 * 1000);
+ testUnaryInstNFLOP("ucumk+*", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testStopInstNFLOP() {
+ long expectedValue = 0;
+ testUnaryInstNFLOP("stop", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testTypeofInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testUnaryInstNFLOP("typeof", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testInverseInstNFLOPDefaultSparsity() {
+ long expectedValue = (long) ((4.0 / 3.0) * (1000 * 1000) * (1000 * 1000) * (1000 * 1000));
+ testUnaryInstNFLOP("inverse", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testInverseInstNFLOPSparse() {
+ long expectedValue = (long) ((4.0 / 3.0) * (1000 * 1000) * (0.5 * 1000 * 1000) * (0.5 * 1000 * 1000));
+ testUnaryInstNFLOP("inverse", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testCholeskyInstNFLOPDefaultSparsity() {
+ long expectedValue = (long) ((1.0 / 3.0) * (1000 * 1000) * (1000 * 1000) * (1000 * 1000));
+ testUnaryInstNFLOP("cholesky", -1, -1, expectedValue);
+ }
+
+ @Test
+ public void testCholeskyInstNFLOPSparse() {
+ long expectedValue = (long) ((1.0 / 3.0) * (1000 * 1000) * (0.5 * 1000 * 1000) * (0.5 * 1000 * 1000));
+ testUnaryInstNFLOP("cholesky", 0.5, 0.5, expectedValue);
+ }
+
+ @Test
+ public void testLogInstNFLOP() {
+ long expectedValue = 32 * 1000 * 1000;
+ testBuiltinInstNFLOP("log", -1, expectedValue);
+ }
+
+ @Test
+ public void testLogNzInstNFLOPDefaultSparsity() {
+ long expectedValue = 32 * 1000 * 1000;
+ testBuiltinInstNFLOP("log_nz", -1, expectedValue);
+ }
+
+ @Test
+ public void testLogNzInstNFLOPSparse() {
+ long expectedValue = (long) (32 * 0.5 * 1000 * 1000);
+ testBuiltinInstNFLOP("log_nz", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testNrowInstNFLOP() {
+ long expectedValue = 10L;
+ testAggregateUnaryInstNFLOP("nrow", expectedValue);
+ }
+
+ @Test
+ public void testNcolInstNFLOP() {
+ long expectedValue = 10L;
+ testAggregateUnaryInstNFLOP("ncol", expectedValue);
+ }
+
+ @Test
+ public void testLengthInstNFLOP() {
+ long expectedValue = 10L;
+ testAggregateUnaryInstNFLOP("length", expectedValue);
+ }
+
+ @Test
+ public void testExistsInstNFLOP() {
+ long expectedValue = 10L;
+ testAggregateUnaryInstNFLOP("exists", expectedValue);
+ }
+
+ @Test
+ public void testLineageInstNFLOP() {
+ long expectedValue = 10L;
+ testAggregateUnaryInstNFLOP("lineage", expectedValue);
+ }
+
+ @Test
+ public void testUakInstNFLOP() {
+ long expectedValue = 4 * 1000 * 1000;
+ testAggregateUnaryInstNFLOP("uak+", expectedValue);
+ }
+
+ @Test
+ public void testUarkInstNFLOP() {
+ long expectedValue = 4L * 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uark+", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uark+", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUackInstNFLOP() {
+ long expectedValue = 4L * 3000 * 3000;
+ testAggregateUnaryColInstNFLOP("uack+", -1, expectedValue);
+ testAggregateUnaryColInstNFLOP("uack+", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUasqkInstNFLOP() {
+ long expectedValue = 5L * 1000 * 1000;
+ testAggregateUnaryInstNFLOP("uasqk+", expectedValue);
+ }
+
+ @Test
+ public void testUarsqkInstNFLOP() {
+ long expectedValue = 5L * 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uarsqk+", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uarsqk+", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUacsqkInstNFLOP() {
+ long expectedValue = 5L * 3000 * 3000;
+ testAggregateUnaryColInstNFLOP("uacsqk+", -1, expectedValue);
+ testAggregateUnaryColInstNFLOP("uacsqk+", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUameanInstNFLOP() {
+ long expectedValue = 7L * 1000 * 1000;
+ testAggregateUnaryInstNFLOP("uamean", expectedValue);
+ }
+
+ @Test
+ public void testUarmeanInstNFLOP() {
+ long expectedValue = 7L * 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uarmean", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uarmean", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUacmeanInstNFLOP() {
+ long expectedValue = 7L * 3000 * 3000;
+ testAggregateUnaryColInstNFLOP("uacmean", -1, expectedValue);
+ testAggregateUnaryColInstNFLOP("uacmean", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUavarInstNFLOP() {
+ long expectedValue = 14L * 1000 * 1000;
+ testAggregateUnaryInstNFLOP("uavar", expectedValue);
+ }
+
+ @Test
+ public void testUarvarInstNFLOP() {
+ long expectedValue = 14L * 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uarvar", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uarvar", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUacvarInstNFLOP() {
+ long expectedValue = 14L * 3000 * 3000;
+ testAggregateUnaryColInstNFLOP("uacvar", -1, expectedValue);
+ testAggregateUnaryColInstNFLOP("uacvar", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUamaxInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testAggregateUnaryInstNFLOP("uamax", expectedValue);
+ }
+
+ @Test
+ public void testUarmaxInstNFLOP() {
+ long expectedValue = 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uarmax", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uarmax", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUarimaxInstNFLOP() {
+ long expectedValue = 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uarimax", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uarimax", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUacmaxInstNFLOP() {
+ long expectedValue = 3000 * 3000;
+ testAggregateUnaryColInstNFLOP("uacmax", -1, expectedValue);
+ testAggregateUnaryColInstNFLOP("uacmax", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUaminInstNFLOP() {
+ long expectedValue = 1000 * 1000;
+ testAggregateUnaryInstNFLOP("uamin", expectedValue);
+ }
+
+ @Test
+ public void testUarminInstNFLOP() {
+ long expectedValue = 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uarmin", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uarmin", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUariminInstNFLOP() {
+ long expectedValue = 2000 * 2000;
+ testAggregateUnaryRowInstNFLOP("uarimin", -1, expectedValue);
+ testAggregateUnaryRowInstNFLOP("uarimin", 0.5, expectedValue);
+ }
+
+ @Test
+ public void testUacminInstNFLOP() {
+ long expectedValue = 3000 * 3000;
+ testAggregateUnaryColInstNFLOP("uacmin", -1, expectedValue);
+ testAggregateUnaryColInstNFLOP("uacmin", 0.5, expectedValue);
+ }
+
+ // HELPERS
+
+ private void testUnaryInstNFLOP(String opcode, double sparsityIn, double sparsityOut, long expectedNFLOP) {
+ long nnzIn = sparsityIn < 0 ? -1 : (long) (sparsityIn * 1000 * 1000);
+ VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, nnzIn);
+ long nnzOut = sparsityOut < 0 ? -1 : (long) (sparsityOut * 1000 * 1000);
+ VarStats output = generateVarStatsMatrix("_mVar2", 1000, 1000, nnzOut);
+
+ long result = CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, input);
+ assertEquals(expectedNFLOP, result);
+ }
+
+ private void testBuiltinInstNFLOP(String opcode, double sparsityIn, long expectedNFLOP) {
+ long nnz = sparsityIn < 0 ? -1 : (long) (sparsityIn * 1000 * 1000);
+ VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, nnz);
+ VarStats output = generateVarStatsMatrix("_mVar2", 1000, 1000, -1);
+
+ long result = CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, input);
+ assertEquals(expectedNFLOP, result);
+ }
+
+ private void testAggregateUnaryInstNFLOP(String opcode, long expectedNFLOP) {
+ VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, -1);
+ VarStats output = generateVarStatsScalarLiteral("_Var2");
+
+ long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input);
+ assertEquals(expectedNFLOP, result);
+ }
+
+ private void testAggregateUnaryRowInstNFLOP(String opcode, double sparsityOut, long expectedNFLOP) {
+ VarStats input = generateVarStatsMatrix("_mVar1", 2000, 1000, -1);
+ long nnzOut = sparsityOut < 0 ? -1 : (long) (sparsityOut * 2000);
+ VarStats output = generateVarStatsMatrix("_mVar2", 2000, 1, nnzOut);
+
+ long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input);
+ assertEquals(expectedNFLOP, result);
+ }
+
+ private void testAggregateUnaryColInstNFLOP(String opcode, double sparsityOut, long expectedNFLOP) {
+ VarStats input = generateVarStatsMatrix("_mVar1", 1000, 3000, -1);
+ long nnzOut = sparsityOut < 0 ? -1 : (long) (sparsityOut * 3000);
+ VarStats output = generateVarStatsMatrix("_mVar2", 1, 3000, nnzOut);
+
+ long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input);
+ assertEquals(expectedNFLOP, result);
+ }
+
+ private VarStats generateVarStatsMatrix(String name, long rows, long cols, long nnz) {
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, nnz);
+ return new VarStats(name, mc);
+ }
+
+ private VarStats generateVarStatsScalarLiteral(String nameOrValue) {
+ return new VarStats(nameOrValue, null);
+ }
+}
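For reference, the expected values in the inverse and cholesky tests above follow one pattern: a per-opcode constant (4/3 for inverse, 1/3 for cholesky, per the literals) times the cell count and the sparsity-scaled input/output sizes of the 1000x1000 test matrices, where sparsity -1 stands for dense. A minimal standalone sketch of that arithmetic (it mirrors the test literals, not necessarily the exact CPCostUtils code path):

public class InverseNFlopSketch {
    // constant 4.0/3.0 for inverse; cholesky uses 1.0/3.0 with the same structure
    static long expectedInverseNFLOP(double sparsityIn, double sparsityOut) {
        long cells = 1000L * 1000L;                                  // 1000x1000 test matrix
        double nnzIn  = (sparsityIn  < 0 ? 1.0 : sparsityIn)  * cells;
        double nnzOut = (sparsityOut < 0 ? 1.0 : sparsityOut) * cells;
        return (long) ((4.0 / 3.0) * cells * nnzIn * nnzOut);
    }
    public static void main(String[] args) {
        System.out.println(expectedInverseNFLOP(-1, -1));   // dense case
        System.out.println(expectedInverseNFLOP(0.5, 0.5)); // 0.5/0.5 sparsity case
    }
}

Both printed values match the literals asserted in testInverseInstNFLOPDefaultSparsity and testInverseInstNFLOPSparse.
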
diff --git a/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
index 6a17e4f..f7ceaf9 100644
--- a/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
+++ b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
@@ -23,6 +23,9 @@
import java.io.FileReader;
import java.util.HashMap;
+import org.apache.sysds.resource.CloudInstance;
+import org.apache.sysds.resource.ResourceCompiler;
+import org.apache.sysds.utils.Explain;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
@@ -36,33 +39,83 @@
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
+import scala.Tuple2;
+
+import static org.apache.sysds.test.component.resource.TestingUtils.getSimpleCloudInstanceMap;
public class CostEstimatorTest extends AutomatedTestBase
{
+ private static final boolean DEBUG_MODE = true;
private static final String TEST_DIR = "component/resource/";
private static final String HOME = SCRIPT_DIR + TEST_DIR;
private static final String TEST_CLASS_DIR = TEST_DIR + CostEstimatorTest.class.getSimpleName() + "/";
+ private static final int DEFAULT_NUM_EXECUTORS = 4;
+ private static final HashMap<String, CloudInstance> INSTANCE_MAP = getSimpleCloudInstanceMap();
@Override
public void setUp() {}
-
- @Test
- public void testKMeans() { runTest("Algorithm_KMeans.dml"); }
@Test
- public void testL2SVM() { runTest("Algorithm_L2SVM.dml"); }
+ public void testL2SVMSingleNode() { runTest("Algorithm_L2SVM.dml", "m5.xlarge", null); }
@Test
- public void testLinreg() { runTest("Algorithm_Linreg.dml"); }
+ public void testL2SVMHybrid() { runTest("Algorithm_L2SVM.dml", "m5.xlarge", "m5.xlarge"); }
@Test
- public void testMLogreg() { runTest("Algorithm_MLogreg.dml"); }
+ public void testLinregSingleNode() { runTest("Algorithm_Linreg.dml", "m5.xlarge", null); }
@Test
- public void testPCA() { runTest("Algorithm_PCA.dml"); }
+ public void testLinregHybrid() { runTest("Algorithm_Linreg.dml", "m5.xlarge", "m5.xlarge"); }
-
- private void runTest( String scriptFilename ) {
+ @Test
+ public void testPCASingleNode() { runTest("Algorithm_PCA.dml", "m5.xlarge", null); }
+ @Test
+ public void testPCAHybrid() { runTest("Algorithm_PCA.dml", "m5.xlarge", "m5.xlarge"); }
+
+ @Test
+ public void testPNMFSingleNode() { runTest("Algorithm_PNMF.dml", "m5.xlarge", null); }
+
+ @Test
+ public void testPNMFHybrid() { runTest("Algorithm_PNMF.dml", "m5.xlarge", "m5.xlarge"); }
+
+ @Test
+ public void testReadAndWriteSingleNode() {
+ Tuple2<String, String> arg1 = new Tuple2<>("$fileA", HOME+"data/A.csv");
+ Tuple2<String, String> arg2 = new Tuple2<>("$fileA_Csv", HOME+"data/A_copy.csv");
+ Tuple2<String, String> arg3 = new Tuple2<>("$fileA_Text", HOME+"data/A_copy_text.text");
+ runTest("ReadAndWrite.dml", "m5.xlarge", null, arg1, arg2, arg3);
+ }
+
+ @Test
+ public void testReadAndWriteHybrid() {
+ Tuple2<String, String> arg1 = new Tuple2<>("$fileA", HOME+"data/A.csv");
+ Tuple2<String, String> arg2 = new Tuple2<>("$fileA_Csv", HOME+"data/A_copy.csv");
+ Tuple2<String, String> arg3 = new Tuple2<>("$fileA_Text", HOME+"data/A_copy_text.text");
+ runTest("ReadAndWrite.dml", "c5.xlarge", "m5.xlarge", arg1, arg2, arg3);
+ }
+
+ @SafeVarargs
+ private void runTest(String scriptFilename, String driverInstance, String executorInstance, Tuple2<String, String>...args) {
+ CloudInstance driver;
+ CloudInstance executor;
+ try {
+ // setting driver node is required
+ driver = INSTANCE_MAP.get(driverInstance);
+ ResourceCompiler.setDriverConfigurations(driver.getMemory(), driver.getVCPUs());
+ // setting executor node is optional: no executor -> single node execution
+ if (executorInstance == null) {
+ executor = null;
+ ResourceCompiler.setSingleNodeExecution();
+ } else {
+ executor = INSTANCE_MAP.get(executorInstance);
+ ResourceCompiler.setExecutorConfigurations(DEFAULT_NUM_EXECUTORS, executor.getMemory(), executor.getVCPUs());
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ throw new RuntimeException("Resource initialization for teh current test failed.");
+ }
try
{
// Tell the superclass about the name of this test, so that the superclass can
@@ -78,8 +131,11 @@
ConfigurationManager.setLocalConfig(conf);
String dmlScriptString="";
+ // assign arguments
HashMap<String, String> argVals = new HashMap<>();
-
+ for (Tuple2<String, String> arg : args)
+ argVals.put(arg._1, arg._2);
+
//read script
try( BufferedReader in = new BufferedReader(new FileReader(HOME + scriptFilename)) ) {
String s1 = null;
@@ -97,13 +153,15 @@
dmlt.rewriteHopsDAG(prog);
dmlt.constructLops(prog);
Program rtprog = dmlt.getRuntimeProgram(prog, ConfigurationManager.getDMLConfig());
-
- //check error-free cost estimation and meaningful result
- Assert.assertTrue(CostEstimator.estimateExecutionTime(rtprog) > 0);
+ if (DEBUG_MODE) System.out.println(Explain.explain(rtprog));
+ double timeCost = CostEstimator.estimateExecutionTime(rtprog, driver, executor);
+ if (DEBUG_MODE) System.out.println("Estimated execution time: " + timeCost + " seconds.");
+ // check error-free cost estimation and meaningful result
+ Assert.assertTrue(timeCost > 0);
}
- catch(Exception ex) {
- ex.printStackTrace();
- //TODO throw new RuntimeException(ex);
+ catch(Exception e) {
+ e.printStackTrace();
+ throw new RuntimeException("Error at parsing the return program for cost estimation");
}
}
}
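Further algorithm/instance combinations can reuse the same parameterized runTest helper; any driver/executor pair from the TestingUtils instance map is valid. A hypothetical extra case (test name and instance pairing chosen purely for illustration):

	// illustrative combination only; any instances from getSimpleCloudInstanceMap() work
	@Test
	public void testL2SVMHybridCompute() { runTest("Algorithm_L2SVM.dml", "c5.xlarge", "c5.2xlarge"); }
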
diff --git a/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java b/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java
index 4555c4b..e333264 100644
--- a/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java
+++ b/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java
@@ -378,11 +378,10 @@
List<CloudInstance> expectedInstances = new ArrayList<>(Arrays.asList(
- instances.get("c5.xlarge"),
- instances.get("m5.xlarge")
+ instances.get("c5.xlarge")
));
// expected solution pool with 0 executors (number executors = 0, executors and executorInstance being null)
- // each solution having one of the available instances as driver node
+ // with a single solution -> the cheapest instance for the driver
Assert.assertEquals(expectedInstances.size(), actualSolutionPoolGB.size());
Assert.assertEquals(expectedInstances.size(), actualSolutionPoolIB.size());
for (int i = 0; i < expectedInstances.size(); i++) {
@@ -428,8 +427,8 @@
gridBasedEnumerator.processing();
SolutionPoint solution = gridBasedEnumerator.postprocessing();
- // expected m5.xlarge since it is the cheaper
- Assert.assertEquals("m5.xlarge", solution.driverInstance.getInstanceName());
+ // expected c5.xlarge since it is the cheapest
+ Assert.assertEquals("c5.xlarge", solution.driverInstance.getInstanceName());
// expected no executor nodes since tested for a 'zero' program
Assert.assertEquals(0, solution.numberExecutors);
}
@@ -466,8 +465,8 @@
gridBasedEnumerator.processing();
SolutionPoint solution = gridBasedEnumerator.postprocessing();
- // expected m5.xlarge since it is the cheaper
- Assert.assertEquals("m5.xlarge", solution.driverInstance.getInstanceName());
+ // expected c5.xlarge since it is the cheapest
+ Assert.assertEquals("c5.xlarge", solution.driverInstance.getInstanceName());
// expected no executor nodes since tested for a 'zero' program
Assert.assertEquals(0, solution.numberExecutors);
}
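The flipped expectations follow from the re-priced instance map in TestingUtils (updated further below in this patch): c5.xlarge now costs 0.194 $/h against 0.23 $/h for m5.xlarge, so a zero-executor solution selects c5.xlarge as the driver. A small standalone check of that ordering, with the prices copied from the fixture:

public class CheapestInstanceCheck {
    public static void main(String[] args) {
        String[] names = {"m5.xlarge", "m5.2xlarge", "c5.xlarge", "c5.2xlarge"};
        double[] pricePerHour = {0.23, 0.46, 0.194, 0.388};  // values from getSimpleCloudInstanceMap()
        int cheapest = 0;
        for (int i = 1; i < pricePerHour.length; i++)
            if (pricePerHour[i] < pricePerHour[cheapest])
                cheapest = i;
        System.out.println(names[cheapest]);  // prints c5.xlarge
    }
}
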
diff --git a/src/test/java/org/apache/sysds/test/component/resource/InstructionsCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/component/resource/InstructionsCostEstimatorTest.java
new file mode 100644
index 0000000..fe2bae9
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/resource/InstructionsCostEstimatorTest.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.resource.CloudInstance;
+import org.apache.sysds.resource.ResourceCompiler;
+import org.apache.sysds.resource.cost.CostEstimationException;
+import org.apache.sysds.resource.cost.CostEstimator;
+import org.apache.sysds.resource.cost.RDDStats;
+import org.apache.sysds.resource.cost.VarStats;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction;
+import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+import static org.apache.sysds.resource.CloudUtils.GBtoBytes;
+import static org.apache.sysds.test.component.resource.TestingUtils.getSimpleCloudInstanceMap;
+
+public class InstructionsCostEstimatorTest {
+ private static final HashMap<String, CloudInstance> instanceMap = getSimpleCloudInstanceMap();
+
+ private CostEstimator estimator;
+
+ @Before
+ public void setup() {
+ ResourceCompiler.setDriverConfigurations(GBtoBytes(8), 4);
+ ResourceCompiler.setExecutorConfigurations(4, GBtoBytes(8), 4);
+ estimator = new CostEstimator(new Program(), instanceMap.get("m5.xlarge"), instanceMap.get("m5.xlarge"));
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // Tests for CP Instructions //
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ @Test
+ public void createvarMatrixVariableCPInstructionTest() throws CostEstimationException {
+ String instDefinition = "CP°createvar°testVar°testOutputFile°false°MATRIX°binary°100°100°1000°10000°COPY";
+ VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
+ testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
+ // test the proper maintainCPInstVariableStatistics functionality
+ estimator.maintainStats(inst);
+ VarStats actualStats = estimator.getStats("testVar");
+ Assert.assertNotNull(actualStats);
+ Assert.assertEquals(10000, actualStats.getCells());
+ }
+
+ @Test
+ public void createvarFrameVariableCPInstructionTest() throws CostEstimationException {
+ String instDefinition = "CP°createvar°testVar°testOutputFile°false°FRAME°binary°100°100°1000°10000°COPY";
+ VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
+ testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
+ // test the proper maintainCPInstVariableStatistics functionality
+ estimator.maintainStats(inst);
+ VarStats actualStats = estimator.getStats("testVar");
+ Assert.assertNotNull(actualStats);
+ Assert.assertEquals(10000, actualStats.getCells());
+ }
+
+ @Test
+ public void createvarInvalidVariableCPInstructionTest() throws CostEstimationException {
+ String instDefinition = "CP°createvar°testVar°testOutputFile°false°TENSOR°binary°100°100°1000°10000°copy";
+ VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
+ try {
+ estimator.maintainStats(inst);
+ testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
+ Assert.fail("Tensor is not supported by the cost estimator");
+ } catch (RuntimeException e) {
+ // needed catch block to assert that RuntimeException has been thrown
+ }
+ }
+
+ @Test
+ public void plusBinaryScalarMatrixCPInstructionTest() throws CostEstimationException {
+ HashMap<String, VarStats> inputStats = new HashMap<>();
+ inputStats.put("matrixVar", generateStats("matrixVar", 10000, 10000, -1));
+ inputStats.put("outputVar", generateStats("outputVar", 10000, 10000, -1));
+
+ String instDefinition = "CP°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
+ BinaryCPInstruction inst = BinaryCPInstruction.parseInstruction(instDefinition);
+ testGettingTimeEstimateForCPInst(estimator, inputStats, inst, -1);
+ }
+
+ @Test
+ public void plusBinaryScalarMatrixCPInstructionExceedMemoryBudgetTest() {
+ HashMap<String, VarStats> inputStats = new HashMap<>();
+ inputStats.put("matrixVar", generateStats("matrixVar", 1000000, 1000000, -1));
+ inputStats.put("outputVar", generateStats("outputVar", 1000000, 1000000, -1));
+
+ String instDefinition = "CP°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
+ BinaryCPInstruction inst = BinaryCPInstruction.parseInstruction(instDefinition);
+ try {
+ testGettingTimeEstimateForCPInst(estimator, inputStats, inst, -1);
+ Assert.fail("CostEstimationException should have been thrown for the given data size and instruction");
+ } catch (CostEstimationException e) {
+ // needed catch block to assert that CostEstimationException has been thrown
+ }
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // Tests for Spark Instructions //
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ @Test
+ public void plusBinaryMatrixMatrixSpInstructionTest() throws CostEstimationException {
+ HashMap<String, VarStats> inputStats = new HashMap<>();
+ inputStats.put("matrixVar", generateStatsWithRdd("matrixVar", 1000000,1000000, 500000000000L));
+ inputStats.put("outputVar", generateStats("outputVar", 1000000,1000000, -1));
+
+ String instDefinition = "SPARK°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
+ BinarySPInstruction inst = BinarySPInstruction.parseInstruction(instDefinition);
+ testGettingTimeEstimateForSparkInst(estimator, inputStats, inst, "outputVar", -1);
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // Helper methods for testing Instructions //
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ private VarStats generateStats(String name, long m, long n, long nnz) {
+ MatrixCharacteristics mc = new MatrixCharacteristics(m, n, nnz);
+ VarStats ret = new VarStats(name, mc);
+ long size = OptimizerUtils.estimateSizeExactSparsity(ret.getM(), ret.getN(), ret.getSparsity());
+ ret.setAllocatedMemory(size);
+ return ret;
+ }
+
+ private VarStats generateStatsWithRdd(String name, long m, long n, long nnz) {
+ MatrixCharacteristics mc = new MatrixCharacteristics(m, n, nnz);
+ VarStats stats = new VarStats(name, mc);
+ RDDStats rddStats = new RDDStats(stats);
+ stats.setRddStats(rddStats);
+ return stats;
+ }
+
+ private static void testGettingTimeEstimateForCPInst(
+ CostEstimator estimator,
+ HashMap<String, VarStats> inputStats,
+ CPInstruction targetInstruction,
+ double expectedCost
+ ) throws CostEstimationException {
+ if (inputStats != null)
+ estimator.putStats(inputStats);
+ double actualCost = estimator.getTimeEstimateInst(targetInstruction);
+
+ if (expectedCost < 0) {
+ // check error-free cost estimation and meaningful result
+ Assert.assertTrue(actualCost > 0);
+ } else {
+ // check error-free cost estimation and exact result
+ Assert.assertEquals(expectedCost, actualCost, 0.0);
+ }
+ }
+
+ private static void testGettingTimeEstimateForSparkInst(
+ CostEstimator estimator,
+ HashMap<String, VarStats> inputStats,
+ SPInstruction targetInstruction,
+ String outputVar,
+ double expectedCost
+ ) throws CostEstimationException {
+ if (inputStats != null)
+ estimator.putStats(inputStats);
+ double actualCost = estimator.getTimeEstimateInst(targetInstruction);
+ RDDStats outputRDD = estimator.getStats(outputVar).getRddStats();
+ if (outputRDD.isCollected()) {
+ // cost directly returned
+ if (expectedCost < 0) {
+ // check error-free cost estimation and meaningful result
+ Assert.assertTrue(actualCost > 0);
+ } else {
+ // check error-free cost estimation and exact result
+ Assert.assertEquals(expectedCost, actualCost, 0.0);
+ }
+ } else {
+ // cost saved in RDD statistics
+ double sparkCost = outputRDD.getCost();
+ if (expectedCost < 0) {
+ // check error-free cost estimation and meaningful result
+ Assert.assertTrue(sparkCost > 0);
+ } else {
+ // check error-free cost estimation and exact result
+ Assert.assertEquals(expectedCost, sparkCost, 0.0);
+ }
+ }
+ }
+
+}
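The two '+' CP tests above differ only in operand size: the first fits the configured 8 GB driver (minus the 10% OS reserve), while the second cannot fit on any single node and therefore has to raise a CostEstimationException. Rough dense-size arithmetic, assuming 8 bytes per FP64 cell and ignoring block/header overhead (the test helper itself sizes variables via OptimizerUtils.estimateSizeExactSparsity):

public class DenseSizeCheck {
    public static void main(String[] args) {
        double smallDense = 10000L * 10000L * 8.0;      // ~0.8 GB -> fits the driver heap
        double hugeDense  = 1000000L * 1000000L * 8.0;  // ~8 TB  -> exceeds any single-node budget
        System.out.printf("10000 x 10000 dense: %.2f GB%n", smallDense / 1e9);
        System.out.printf("1000000 x 1000000 dense: %.2f TB%n", hugeDense / 1e12);
    }
}
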
diff --git a/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java
index c1a68b0..f801982 100644
--- a/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java
+++ b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java
@@ -56,10 +56,11 @@
@Test
public void testSetDriverConfigurations() {
- long expectedMemory = 1024*1024*1024; // 1GB
+ long nodeMemory = 1024*1024*1024; // 1GB
+ long expectedMemory = (long) (0.9 * nodeMemory);
int expectedThreads = 4;
- ResourceCompiler.setDriverConfigurations(expectedMemory, expectedThreads);
+ ResourceCompiler.setDriverConfigurations(nodeMemory, expectedThreads);
Assert.assertEquals(expectedMemory, InfrastructureAnalyzer.getLocalMaxMemory());
Assert.assertEquals(expectedThreads, InfrastructureAnalyzer.getLocalParallelism());
@@ -170,9 +171,9 @@
runTest(precompiledProgram, expectedProgram, 8L*1024*1024*1024, 0, -1, "ba+*", false);
ResourceCompiler.setDriverConfigurations(16L*1024*1024*1024, driverThreads);
- ResourceCompiler.setExecutorConfigurations(2, 1024*1024*1024, executorThreads);
+ ResourceCompiler.setExecutorConfigurations(4, 1024*1024*1024, executorThreads);
expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
- runTest(precompiledProgram, expectedProgram, 16L*1024*1024*1024, 2, 1024*1024*1024, "ba+*", false);
+ runTest(precompiledProgram, expectedProgram, 16L*1024*1024*1024, 4, 1024*1024*1024, "ba+*", false);
ResourceCompiler.setDriverConfigurations(4L*1024*1024*1024, driverThreads);
ResourceCompiler.setExecutorConfigurations(2, 4L*1024*1024*1024, executorThreads);
diff --git a/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java b/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java
index 035fac6..38dde48 100644
--- a/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java
+++ b/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java
@@ -47,10 +47,10 @@
public static HashMap<String, CloudInstance> getSimpleCloudInstanceMap() {
HashMap<String, CloudInstance> instanceMap = new HashMap<>();
// fill the map with enough cloud instances to allow testing all search space dimension iterations
- instanceMap.put("m5.xlarge", new CloudInstance("m5.xlarge", GBtoBytes(16), 4, 0.5, 0.0, 143.75, 160, 1.5));
- instanceMap.put("m5.2xlarge", new CloudInstance("m5.2xlarge", GBtoBytes(32), 8, 1.0, 0.0, 0.0, 0.0, 1.9));
- instanceMap.put("c5.xlarge", new CloudInstance("c5.xlarge", GBtoBytes(8), 4, 0.5, 0.0, 0.0, 0.0, 1.7));
- instanceMap.put("c5.2xlarge", new CloudInstance("c5.2xlarge", GBtoBytes(16), 8, 1.0, 0.0, 0.0, 0.0, 2.1));
+ instanceMap.put("m5.xlarge", new CloudInstance("m5.xlarge", GBtoBytes(16), 4, 0.34375, 21328.0, 143.75, 160.0, 0.23));
+ instanceMap.put("m5.2xlarge", new CloudInstance("m5.2xlarge", GBtoBytes(32), 8, 0.6875, 21328.0, 287.50, 320.0, 0.46));
+ instanceMap.put("c5.xlarge", new CloudInstance("c5.xlarge", GBtoBytes(8), 4, 0.46875, 21328.0, 143.75, 160.0, 0.194));
+ instanceMap.put("c5.2xlarge", new CloudInstance("c5.2xlarge", GBtoBytes(16), 8, 0.9375, 21328.0, 287.50, 320.0, 0.388));
return instanceMap;
}
@@ -60,10 +60,10 @@
List<String> csvLines = Arrays.asList(
"API_Name,Memory,vCPUs,gFlops,ramSpeed,diskSpeed,networkSpeed,Price",
- "m5.xlarge,16.0,4,0.5,0,143.75,160,1.5",
- "m5.2xlarge,32.0,8,1.0,0,0,0,1.9",
- "c5.xlarge,8.0,4,0.5,0,0,0,1.7",
- "c5.2xlarge,16.0,8,1.0,0,0,0,2.1"
+ "m5.xlarge,16.0,4,0.34375,21328.0,143.75,160.0,0.23",
+ "m5.2xlarge,32.0,8,0.6875,21328.0,287.50,320.0,0.46",
+ "c5.xlarge,8.0,4,0.46875,21328.0,143.75,160.0,0.194",
+ "c5.2xlarge,16.0,8,0.9375,21328.0,287.50,320.0,0.388"
);
Files.write(tmpFile.toPath(), csvLines);
return tmpFile;
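The CSV rows mirror the CloudInstance constructor arguments used in getSimpleCloudInstanceMap(), with the Memory column given in GB and converted via GBtoBytes. A minimal parsing sketch that one could drop into TestingUtils (field order assumed from the header line; this is illustration only, not the production CSV loader):

	// hypothetical helper; relies on the static import of CloudUtils.GBtoBytes already present here
	static CloudInstance parseInstanceRow(String csvRow) {
		String[] f = csvRow.split(",");
		return new CloudInstance(
			f[0],                                 // API_Name, e.g. "m5.xlarge"
			GBtoBytes(Double.parseDouble(f[1])),  // Memory (GB in the CSV, bytes in CloudInstance)
			Integer.parseInt(f[2]),               // vCPUs
			Double.parseDouble(f[3]),             // gFlops
			Double.parseDouble(f[4]),             // ramSpeed
			Double.parseDouble(f[5]),             // diskSpeed
			Double.parseDouble(f[6]),             // networkSpeed
			Double.parseDouble(f[7]));            // Price per hour
	}
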
diff --git a/src/test/scripts/component/resource/Algorithm_MLogreg.dml b/src/test/scripts/component/resource/Algorithm_MLogreg.dml
deleted file mode 100644
index 1cef70e..0000000
--- a/src/test/scripts/component/resource/Algorithm_MLogreg.dml
+++ /dev/null
@@ -1,26 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-X = rand(rows=10000, cols=10);
-Y = X %*% rand(rows=10, cols=1);
-w = multiLogReg(X=X, Y=Y, icpt=2, tol=1e-8, reg=0.01, maxi=20);
-print(sum(w));
-
diff --git a/src/test/scripts/component/resource/Algorithm_KMeans.dml b/src/test/scripts/component/resource/Algorithm_PNMF.dml
similarity index 89%
copy from src/test/scripts/component/resource/Algorithm_KMeans.dml
copy to src/test/scripts/component/resource/Algorithm_PNMF.dml
index 67bc3c6..5737585 100644
--- a/src/test/scripts/component/resource/Algorithm_KMeans.dml
+++ b/src/test/scripts/component/resource/Algorithm_PNMF.dml
@@ -19,7 +19,9 @@
#
#-------------------------------------------------------------
-X = rand(rows=10000, cols=10);
-C = kmeans(X=X, k=4, runs=10, eps=1e-8, max_iter=20);
-print(sum(C));
+X = rand(rows=100000, cols=1000);
+rank = 10;
+
+[w, h] = pnmf(X=X, rnk=rank, verbose=FALSE);
+print(sum(w));
diff --git a/src/test/scripts/component/resource/Algorithm_KMeans.dml b/src/test/scripts/component/resource/ReadAndWrite.dml
similarity index 89%
rename from src/test/scripts/component/resource/Algorithm_KMeans.dml
rename to src/test/scripts/component/resource/ReadAndWrite.dml
index 67bc3c6..98995c3 100644
--- a/src/test/scripts/component/resource/Algorithm_KMeans.dml
+++ b/src/test/scripts/component/resource/ReadAndWrite.dml
@@ -19,7 +19,6 @@
#
#-------------------------------------------------------------
-X = rand(rows=10000, cols=10);
-C = kmeans(X=X, k=4, runs=10, eps=1e-8, max_iter=20);
-print(sum(C));
-
+A = read($fileA);
+write(A, $fileA_Csv, format="csv");
+write(A, $fileA_Text, format="text");