// blob: fffc2dcbce49d9af4c468aff8bd27602ebfb4cd3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.lineage;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.FunctionOp.FunctionType;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.spark.RandSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
import org.apache.sysds.utils.Explain;
import org.apache.sysds.utils.Explain.ExplainCounts;
import org.apache.sysds.utils.Statistics;
public class LineageRecomputeUtils {
// Prefix of the variable name under which the recomputed result is stored
// (the ID of the root lineage item is appended in computeByLineage).
private static final String LVARPREFIX = "lvar";
// Prefix of creation items that stand in for dedup patch inputs (e.g., IN#0, IN#1);
// the number after the prefix is the position of the function input.
public static final String LPLACEHOLDER = "IN#";
// Compile-time switch: if true, computeByLineage prints the runtime plan
// before execution and the statistics afterwards.
private static final boolean DEBUG = false;
// Dedup loop items keyed by the loop ID that appears in dedup opcodes; looked up in
// constructHopsDedupPatch and cleared at the end of parseNComputeLineageTrace.
// NOTE(review): presumably filled by LineageParser.parseLineageTraceDedup -- confirm.
public static Map<String, DedupLoopItem> loopPatchMap = new HashMap<>();
/**
 * Parses a serialized lineage trace (and, if given, its dedup patches) and
 * recomputes the data object described by the trace.
 *
 * @param mainTrace    serialized lineage trace of the result to recompute
 * @param dedupPatches serialized dedup patches, or {@code null} if there are none
 * @return the variable produced by executing the reconstructed program
 */
public static Data parseNComputeLineageTrace(String mainTrace, String dedupPatches) {
	try {
		LineageItem root = LineageParser.parseLineageTrace(mainTrace);
		if (dedupPatches != null)
			LineageParser.parseLineageTraceDedup(dedupPatches);
		return computeByLineage(root);
	}
	finally {
		// Cleanup the statics on every exit path: a failed parse or recomputation
		// must not leave stale loop patches behind for the next invocation.
		loopPatchMap.clear();
	}
}
/**
 * Rebuilds a runtime program from the lineage DAG rooted at the given item,
 * executes it, and returns the variable holding the result.
 *
 * @param root root of the (parsed) lineage DAG
 * @return the result variable read from the execution context after execution
 */
private static Data computeByLineage(LineageItem root)
{
	// for a "write" root, the result is identified by the ID of its first input
	final long resultId;
	if (root.getOpcode().equals("write"))
		resultId = root.getInputs()[0].getId();
	else
		resultId = root.getId();
	final String resultVar = LVARPREFIX + resultId;
	final Program program = new Program(null);

	// translate the lineage items into hops (inputs before consumers)
	root.resetVisitStatusNR();
	final Map<Long, Hop> hopsById = new HashMap<>();
	final Map<String, Hop> dagRoots = new HashMap<>();
	rConstructHops(root, hopsById, dagRoots, program);
	final Hop resultWrite = HopRewriteUtils.createTransientWrite(
		resultVar, hopsById.get(resultId));

	// compile the final hop DAG into a basic program block
	final ExecutionContext ec = ExecutionContextFactory.createContext();
	dagRoots.put(resultVar, resultWrite);
	constructBasicBlock(dagRoots, resultVar, program);

	// the cache must be reset because data objects have been cleaned up
	LineageCache.resetCache();

	// optionally print the runtime plan, then execute it
	if (DEBUG) {
		DMLScript.STATISTICS = true;
		ExplainCounts counts = Explain.countDistributedOperations(program);
		System.out.println(Explain.display(null, program, Explain.ExplainType.RUNTIME, counts));
	}
	ec.setProgram(program);
	program.execute(ec);
	if (DEBUG) {
		Statistics.stopRunTimer();
		System.out.println(Statistics.display(DMLScript.STATISTICS_COUNT));
	}

	return ec.getVariable(resultVar);
}
/**
 * Compiles the hop DAG registered under the given name into instructions and
 * appends them to the program as a new basic program block.
 *
 * @param partDagRoots hop DAG roots by output name
 * @param dedupOut     name of the DAG root to compile
 * @param prog         program that receives the new block
 */
private static void constructBasicBlock(Map<String, Hop> partDagRoots, String dedupOut, Program prog) {
	final Hop dagRoot = partDagRoots.get(dedupOut);
	final BasicProgramBlock block = new BasicProgramBlock(prog);
	// compile the DAG and attach the instructions to the new block
	final ArrayList<Instruction> instructions = genInst(dagRoot);
	block.setInstructions(instructions);
	prog.addProgramBlock(block);
}
/**
 * Recursively translates a lineage item and its inputs into hops. The hop created
 * for an item is registered in {@code operands} under the item's ID, so that
 * consumers can look up their inputs by lineage item ID.
 *
 * @param item         lineage item to translate; items already marked visited are skipped
 * @param operands     hops created so far, by lineage item ID (read and updated)
 * @param partDagRoots hop DAG roots by output name; only touched for dedup items
 *                     (constructHopsDedupPatch passes null)
 * @param prog         program receiving the blocks and functions compiled for dedup
 *                     items (constructHopsDedupPatch passes null)
 * @throws DMLRuntimeException if an instruction item has an unsupported opcode or type
 */
private static void rConstructHops(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog)
{
if (item.isVisited())
return;
//recursively process children (ordering by data dependencies)
if (!item.isLeaf())
for (LineageItem c : item.getInputs())
rConstructHops(c, operands, partDagRoots, prog);
//process current lineage item
//NOTE: we generate instructions from hops (but without rewrites) to automatically
//handle execution types, rmvar instructions, and rewiring of inputs/outputs
//NOTE: the switch has no default branch; item types without a case below are
//only marked as visited and leave no entry in operands
switch (item.getType()) {
case Creation: {
if (item.getData().startsWith(LPLACEHOLDER)) {
//the placeholder ID is the number following the 3-character prefix "IN#"
long phId = Long.parseLong(item.getData().substring(3));
//move the hop from its positional key to the ID of this lineage item
Hop input = operands.get(phId);
operands.remove(phId);
// Replace the placeholders with TReads
operands.put(item.getId(), input); // order preserving
break;
}
//otherwise the data of a creation item is parsed as a single instruction
Instruction inst = InstructionParser.parseSingleInstruction(item.getData());
if (inst instanceof DataGenCPInstruction) {
DataGenCPInstruction rand = (DataGenCPInstruction) inst;
HashMap<String, Hop> params = new HashMap<>();
if( rand.getOpcode().equals("rand") ) {
if( rand.output.getDataType() == DataType.TENSOR)
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
else {
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
}
params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
}
else if( rand.getOpcode().equals("seq") ) {
params.put(Statement.SEQ_FROM, new LiteralOp(rand.getFrom()));
params.put(Statement.SEQ_TO, new LiteralOp(rand.getTo()));
params.put(Statement.SEQ_INCR, new LiteralOp(rand.getIncr()));
}
//any other datagen opcode is created without parameters
Hop datagen = new DataGenOp(OpOpDG.valueOf(rand.getOpcode().toUpperCase()),
new DataIdentifier("tmp"), params);
datagen.setBlocksize(rand.getBlocksize());
operands.put(item.getId(), datagen);
} else if (inst instanceof VariableCPInstruction
&& ((VariableCPInstruction) inst).isCreateVariable()) {
//createvar is turned into a persistent read built from the instruction parts
String parts[] = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
DataType dt = DataType.valueOf(parts[4]);
//matrices are read as FP64, every other data type as STRING
ValueType vt = dt == DataType.MATRIX ? ValueType.FP64 : ValueType.STRING;
HashMap<String, Hop> params = new HashMap<>();
params.put(DataExpression.IO_FILENAME, new LiteralOp(parts[2]));
params.put(DataExpression.READROWPARAM, new LiteralOp(Long.parseLong(parts[6])));
params.put(DataExpression.READCOLPARAM, new LiteralOp(Long.parseLong(parts[7])));
params.put(DataExpression.READNNZPARAM, new LiteralOp(Long.parseLong(parts[8])));
params.put(DataExpression.FORMAT_TYPE, new LiteralOp(parts[5]));
//the hop name is parts[1] without its first five characters
DataOp pread = new DataOp(parts[1].substring(5), dt, vt, OpOpData.PERSISTENTREAD, params);
pread.setFileName(parts[2]);
operands.put(item.getId(), pread);
}
else if (inst instanceof RandSPInstruction) {
RandSPInstruction rand = (RandSPInstruction) inst;
HashMap<String, Hop> params = new HashMap<>();
if (rand.output.getDataType() == DataType.TENSOR)
params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
else {
params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
}
params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
Hop datagen = new DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params);
datagen.setBlocksize(rand.getBlocksize());
operands.put(item.getId(), datagen);
}
//NOTE: any other instruction kind leaves no entry in operands for this item
break;
}
case Dedup: {
// Create function call for each dedup entry
String[] parts = item.getOpcode().split(LineageDedupUtils.DEDUP_DELIM); //e.g. dedup_R_SB13_0
String name = parts[2] + parts[1] + parts[3]; //loopId + outVar + pathId
List<Hop> finputs = Arrays.stream(item.getInputs())
.map(inp -> operands.get(inp.getId())).collect(Collectors.toList());
String[] inputNames = new String[item.getInputs().length];
for (int i=0; i<item.getInputs().length; i++)
inputNames[i] = LPLACEHOLDER + i; //e.g. IN#0, IN#1
Hop funcOp = new FunctionOp(FunctionType.DML, DMLProgram.DEFAULT_NAMESPACE,
name, inputNames, finputs, new String[] {parts[1]}, false);
// Cut the Hop dag after function calls
partDagRoots.put(parts[1], funcOp);
// Compile the dag and save
constructBasicBlock(partDagRoots, parts[1], prog);
// Construct a Hop dag for the function body from the dedup patch, and compile
Hop output = constructHopsDedupPatch(parts, inputNames, finputs, prog);
// Create a TRead on the function o/p as a leaf for the next Hop dag
// Use the function body root/return hop to propagate right data type
operands.put(item.getId(), HopRewriteUtils.createTransientRead(parts[1], output));
break;
}
case Instruction: {
//resolve the opcode as CP and as Spark type; the CP type takes precedence
CPType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
SPType stype = InstructionUtils.getSPTypeByOpcode(item.getOpcode());
if (ctype != null) {
switch (ctype) {
case AggregateUnary: {
Hop input = operands.get(item.getInputs()[0].getId());
//unary metadata opcodes become unary ops, all others unary aggregates
Hop aggunary = InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
operands.put(item.getId(), aggunary);
break;
}
case AggregateBinary: {
Hop input1 = operands.get(item.getInputs()[0].getId());
Hop input2 = operands.get(item.getInputs()[1].getId());
Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
operands.put(item.getId(), aggbinary);
break;
}
case AggregateTernary: {
//rebuilt as sum((input1 * input2) * input3)
Hop input1 = operands.get(item.getInputs()[0].getId());
Hop input2 = operands.get(item.getInputs()[1].getId());
Hop input3 = operands.get(item.getInputs()[2].getId());
Hop aggternary = HopRewriteUtils.createSum(
HopRewriteUtils.createBinary(
HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
input3, OpOp2.MULT));
operands.put(item.getId(), aggternary);
break;
}
case Unary:
case Builtin: {
Hop input = operands.get(item.getInputs()[0].getId());
Hop unary = HopRewriteUtils.createUnary(input, item.getOpcode());
operands.put(item.getId(), unary);
break;
}
case Reorg: {
operands.put(item.getId(), HopRewriteUtils.createReorg(
operands.get(item.getInputs()[0].getId()), item.getOpcode()));
break;
}
case Reshape: {
//the first five inputs of the item are forwarded to the reshape hop
ArrayList<Hop> inputs = new ArrayList<>();
for(int i=0; i<5; i++)
inputs.add(operands.get(item.getInputs()[i].getId()));
operands.put(item.getId(), HopRewriteUtils.createReorg(inputs, ReOrgOp.RESHAPE));
break;
}
case Binary: {
//handle special cases of binary operations
//("^2" and "*2" are reduced to their first character, i.e., "^" and "*")
String opcode = ("^2".equals(item.getOpcode())
|| "*2".equals(item.getOpcode())) ?
item.getOpcode().substring(0, 1) : item.getOpcode();
Hop input1 = operands.get(item.getInputs()[0].getId());
Hop input2 = operands.get(item.getInputs()[1].getId());
Hop binary = HopRewriteUtils.createBinary(input1, input2, opcode);
operands.put(item.getId(), binary);
break;
}
case Ternary: {
operands.put(item.getId(), HopRewriteUtils.createTernary(
operands.get(item.getInputs()[0].getId()),
operands.get(item.getInputs()[1].getId()),
operands.get(item.getInputs()[2].getId()), item.getOpcode()));
break;
}
case Ctable: { //e.g., ctable
//only items with exactly 3 or 5 inputs are translated; any other
//input count leaves no entry in operands
if( item.getInputs().length==3 )
operands.put(item.getId(), HopRewriteUtils.createTernary(
operands.get(item.getInputs()[0].getId()),
operands.get(item.getInputs()[1].getId()),
operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
else if( item.getInputs().length==5 )
operands.put(item.getId(), HopRewriteUtils.createTernary(
operands.get(item.getInputs()[0].getId()),
operands.get(item.getInputs()[1].getId()),
operands.get(item.getInputs()[2].getId()),
operands.get(item.getInputs()[3].getId()),
operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
break;
}
case BuiltinNary: {
//the opcode "n+" is mapped to the nary operation PLUS
String opcode = item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
operands.put(item.getId(), HopRewriteUtils.createNary(
OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands)));
break;
}
case ParameterizedBuiltin: {
operands.put(item.getId(), constructParameterizedBuiltinOp(item, operands));
break;
}
case MatrixIndexing: {
operands.put(item.getId(), constructIndexingOp(item, operands));
break;
}
case MMTSJ: {
//TODO handling of tsmm type left and right -> placement transpose
//currently always rebuilt as t(input) %*% input
Hop input = operands.get(item.getInputs()[0].getId());
Hop aggunary = HopRewriteUtils.createMatrixMultiply(
HopRewriteUtils.createTranspose(input), input);
operands.put(item.getId(), aggunary);
break;
}
case Variable: {
if( item.getOpcode().startsWith("cast") )
operands.put(item.getId(), HopRewriteUtils.createUnary(
operands.get(item.getInputs()[0].getId()),
OpOp1.valueOfByOpcode(item.getOpcode())));
else //cpvar, write
//no new hop: the item is an alias of the hop of its first input
operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
break;
}
default:
throw new DMLRuntimeException("Unsupported instruction "
+ "type: " + ctype.name() + " (" + item.getOpcode() + ").");
}
}
else if( stype != null ) {
switch(stype) {
case Reblock: {
//no new hop: the input hop is modified in place (blocksize and
//reblock flag) and registered under the ID of this item as well
Hop input = operands.get(item.getInputs()[0].getId());
input.setBlocksize(ConfigurationManager.getBlocksize());
input.setRequiresReblock(true);
operands.put(item.getId(), input);
break;
}
case Checkpoint: {
//no new hop: the item is an alias of its input hop
Hop input = operands.get(item.getInputs()[0].getId());
operands.put(item.getId(), input);
break;
}
case MatrixIndexing: {
operands.put(item.getId(), constructIndexingOp(item, operands));
break;
}
case GAppend: {
//rebuilt as a cbind of the first two inputs
operands.put(item.getId(), HopRewriteUtils.createBinary(
operands.get(item.getInputs()[0].getId()),
operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
break;
}
default:
throw new DMLRuntimeException("Unsupported instruction "
+ "type: " + stype.name() + " (" + item.getOpcode() + ").");
}
}
else
throw new DMLRuntimeException("Unsupported instruction: " + item.getOpcode());
break;
}
case Literal: {
//the item data is parsed as an operand to obtain the value type and the value
CPOperand op = new CPOperand(item.getData());
operands.put(item.getId(), ScalarObjectFactory
.createLiteralOp(op.getValueType(), op.getName()));
break;
}
}
item.setVisited();
}
/**
 * Constructs the hop DAG of a dedup patch and compiles it into a function
 * program block that is registered in the given program. Patches that have
 * been compiled before are served from the per-loop cache.
 *
 * @param parts   split dedup opcode: parts[1] is the output variable,
 *                parts[2] the loop ID, and parts[3] the path ID
 * @param inputs  names of the function inputs (placeholders)
 * @param inpHops hops feeding the function inputs, aligned with {@code inputs}
 * @param prog    program in which the function is registered
 * @return the transient write hop at the root of the function body
 */
private static Hop constructHopsDedupPatch(String[] parts, String[] inputs, List<Hop> inpHops, Program prog) {
	final String outVar = parts[1];
	final Long pathKey = Long.parseLong(parts[3]);
	final DedupLoopItem loopItem = loopPatchMap.get(parts[2]);

	// Nothing to do if this patch has been compiled before
	final Map<String, Hop> compiled = loopItem.patchHopMap.get(pathKey);
	if (compiled != null && compiled.containsKey(outVar))
		return compiled.get(outVar);

	// Fetch the lineage root of the patch and reset its visit status
	final LineageItem patchRoot = loopItem.patchLiMap.get(pathKey).get(outVar);
	patchRoot.resetVisitStatusNR();

	// Seed the operands with TReads on the function inputs, keyed by position
	//FIXME: the keys of operands can be replaced inside rConstructHops
	final Map<Long, Hop> patchOperands = new HashMap<>();
	int pos = 0;
	for (String inputName : inputs) {
		patchOperands.put((long)pos, HopRewriteUtils.createTransientRead(inputName, inpHops.get(pos))); //order preserving
		pos++;
	}

	// Build the hop DAG of the function body
	rConstructHops(patchRoot, patchOperands, null, null);

	// TWrite the function return (the DAG root is passed to copy the data type)
	final Hop funcRoot = HopRewriteUtils.createTransientWrite(outVar, patchOperands.get(patchRoot.getId()));

	// Remember the hop DAG for later calls with the same path and output
	loopItem.patchHopMap.computeIfAbsent(pathKey, k -> new HashMap<>()).put(outVar, funcRoot);

	// Describe the function signature
	final List<DataIdentifier> fnInputs = new ArrayList<>();
	for (int i=0; i<inpHops.size(); i++) {
		String inputName = inputs[i];
		Hop inputHop = inpHops.get(i);
		fnInputs.add(new DataIdentifier(inputName, inputHop.getDataType(), inputHop.getValueType()));
	}
	// TODO: multi-return function
	final List<DataIdentifier> fnOutputs = new ArrayList<>();
	fnOutputs.add(new DataIdentifier(outVar));

	// Compile the body to instructions and register it as a FunctionProgramBlock
	final FunctionProgramBlock function = new FunctionProgramBlock(prog, fnInputs, fnOutputs);
	final BasicProgramBlock body = new BasicProgramBlock(prog);
	body.setInstructions(genInst(funcRoot));
	function.addProgramBlock(body);
	prog.addFunctionProgramBlock(DMLProgram.DEFAULT_NAMESPACE, parts[2]+parts[1]+parts[3], function);
	//function.setRecompileOnce(true);
	return funcRoot;
}
/**
 * Constructs the lops for the hop DAG rooted at the given hop and returns
 * the instructions generated from them.
 *
 * @param root root of the hop DAG to compile
 * @return the generated instructions
 */
private static ArrayList<Instruction> genInst (Hop root) {
	final Dag<Lop> lopDag = new Dag<>();
	root.constructLops().addToDag(lopDag);
	return lopDag.getJobs(null, ConfigurationManager.getDMLConfig());
}
/**
 * Looks up the hops of all inputs of the given lineage item.
 *
 * @param item     lineage item whose inputs are resolved
 * @param operands hops by lineage item ID
 * @return the input hops in input order (an entry is null if no hop is registered)
 */
private static Hop[] createNaryInputs(LineageItem item, Map<Long, Hop> operands) {
	return Arrays.stream(item.getInputs())
		.map(input -> operands.get(input.getId()))
		.toArray(Hop[]::new);
}
/**
 * Creates the parameterized builtin hop for the given lineage item. The first
 * input of the item is the target; for the opcodes groupedagg, rmempty, replace,
 * and rexpand the remaining inputs are bound to the named parameters by position.
 * Any other opcode yields an operation without named parameters.
 *
 * @param item     lineage item of the parameterized builtin instruction
 * @param operands hops by lineage item ID
 * @return the parameterized builtin hop
 * @throws IllegalArgumentException if the opcode does not name a ParamBuiltinOp
 */
private static Hop constructParameterizedBuiltinOp(LineageItem item, Map<Long, Hop> operands) {
	String opcode = item.getOpcode();
	Hop target = operands.get(item.getInputs()[0].getId());
	LinkedHashMap<String,Hop> args = new LinkedHashMap<>();
	// all opcodes are matched case-insensitively, consistent with the
	// case-insensitive resolution of the operation type below
	if( opcode.equalsIgnoreCase("groupedagg") ) {
		args.put("target", target);
		args.put(Statement.GAGG_GROUPS, operands.get(item.getInputs()[1].getId()));
		args.put(Statement.GAGG_WEIGHTS, operands.get(item.getInputs()[2].getId()));
		args.put(Statement.GAGG_FN, operands.get(item.getInputs()[3].getId()));
		args.put(Statement.GAGG_NUM_GROUPS, operands.get(item.getInputs()[4].getId()));
	}
	else if( opcode.equalsIgnoreCase("rmempty") ) {
		args.put("target", target);
		args.put("margin", operands.get(item.getInputs()[1].getId()));
		args.put("select", operands.get(item.getInputs()[2].getId()));
	}
	else if( opcode.equalsIgnoreCase("replace") ) {
		args.put("target", target);
		args.put("pattern", operands.get(item.getInputs()[1].getId()));
		args.put("replacement", operands.get(item.getInputs()[2].getId()));
	}
	else if( opcode.equalsIgnoreCase("rexpand") ) {
		args.put("target", target);
		args.put("max", operands.get(item.getInputs()[1].getId()));
		args.put("dir", operands.get(item.getInputs()[2].getId()));
		args.put("cast", operands.get(item.getInputs()[3].getId()));
		args.put("ignore", operands.get(item.getInputs()[4].getId()));
	}
	// upper-case with a fixed locale: the default-locale variant maps 'i' to a
	// dotted capital I under e.g. a Turkish locale, which breaks the enum lookup
	return HopRewriteUtils.createParameterizedBuiltinOp(
		target, args, ParamBuiltinOp.valueOf(opcode.toUpperCase(java.util.Locale.ROOT)));
}
/**
 * Creates the indexing hop for the given lineage item: a right indexing for
 * the opcode rightIndex, and a left indexing for leftIndex and mapLeftIndex.
 *
 * @param item     lineage item of the indexing instruction
 * @param operands hops by lineage item ID
 * @return the (left) indexing hop
 * @throws DMLRuntimeException if the opcode is none of the three above
 */
private static Hop constructIndexingOp(LineageItem item, Map<Long, Hop> operands) {
	final Hop target = operands.get(item.getInputs()[0].getId());
	final String opcode = item.getOpcode();
	final boolean rightIndexing = "rightIndex".equals(opcode);
	final boolean leftIndexing = "leftIndex".equals(opcode) || "mapLeftIndex".equals(opcode);
	if( !rightIndexing && !leftIndexing )
		throw new DMLRuntimeException("Unsupported opcode: "+opcode);

	final LineageItem[] in = item.getInputs();
	if( rightIndexing ) {
		// inputs: target, row lower/upper, column lower/upper
		Hop rl = operands.get(in[1].getId());
		Hop ru = operands.get(in[2].getId());
		Hop cl = operands.get(in[3].getId());
		Hop cu = operands.get(in[4].getId());
		return HopRewriteUtils.createIndexingOp(target, rl, ru, cl, cu);
	}

	// inputs: target, right-hand side, row lower/upper, column lower/upper
	Hop rhs = operands.get(in[1].getId());
	Hop rl = operands.get(in[2].getId());
	Hop ru = operands.get(in[3].getId());
	Hop cl = operands.get(in[4].getId());
	Hop cu = operands.get(in[5].getId());
	return HopRewriteUtils.createLeftIndexingOp(target, rhs, rl, ru, cl, cu);
}
/**
 * Represents a single loop and holds the data of that loop that is needed
 * for recomputation.
 */
protected static class DedupLoopItem {
	/** Name given at construction. */
	public String functionName;

	/** Lineage DAG per unique path (outer key) and output variable (inner key). */
	public final Map<Long, Map<String, LineageItem>> patchLiMap = new HashMap<>();

	/** Hop DAG per unique path (outer key) and output variable (inner key). */
	private final Map<Long, Map<String, Hop>> patchHopMap = new HashMap<>();

	/**
	 * Creates a loop item without any patches.
	 *
	 * @param name value stored in {@link #functionName}
	 */
	public DedupLoopItem(String name) {
		this.functionName = name;
	}
}
}