// blob: 1558beb34424f83c1c8b0ef68978722029adb6ef [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.lineage;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.EvalNaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem.LineageItemType;
import org.apache.sysds.utils.Explain;
import java.util.HashMap;
import java.util.Map;
public class LineageMap {
private final Map<String, LineageItem> _traces;
private final Map<String, LineageItem> _literals;
public LineageMap() {
_traces = new HashMap<>();
_literals = new HashMap<>();
}
public LineageMap(LineageMap that) {
this();
_traces.putAll(that._traces);
_literals.putAll(that._literals);
}
public void trace(Instruction inst, ExecutionContext ec) {
if( inst instanceof FunctionCallCPInstruction || inst instanceof EvalNaryCPInstruction)
return; // no need for lineage tracing
if (!(inst instanceof LineageTraceable))
throw new DMLRuntimeException("Unknown Instruction (" + inst.getOpcode() + ") traced.");
LineageTraceable linst = (LineageTraceable) inst;
if( linst.hasSingleLineage() ) {
trace(inst, ec, linst.getLineageItem(ec));
}
else {
Pair<String, LineageItem>[] items = linst.getLineageItems(ec);
if (items == null || items.length < 1)
trace(inst, ec, null);
else {
for (Pair<String, LineageItem> li : items)
trace(inst, ec, cleanupInputLiterals(li, ec));
}
}
}
public void processDedupItem(LineageMap lm, Long path, LineageItem[] liinputs, String name) {
String delim = LineageDedupUtils.DEDUP_DELIM;
for (Map.Entry<String, LineageItem> entry : lm._traces.entrySet()) {
// Encode everything in the opcode needed by the deserialization logic
// to map this lineage item to the right patch.
String opcode = LineageItem.dedupItemOpcode + delim + entry.getKey()
+ delim + name + delim + path.toString();
LineageItem li = new LineageItem(opcode, liinputs);
addLineageItem(Pair.of(entry.getKey(), li));
}
}
public LineageItem getOrCreate(CPOperand variable) {
if (variable == null)
return null;
String varname = variable.getName();
//handle literals (never in traces)
if (variable.isLiteral()) {
LineageItem ret = _literals.get(varname);
if (ret == null)
_literals.put(varname, ret = new LineageItem(variable.getLineageLiteral()));
return ret;
}
//handle variables
LineageItem ret = _traces.get(variable.getName());
return (ret != null) ? ret :
new LineageItem(variable.getLineageLiteral());
}
public LineageItem get(String varName) {
return _traces.get(varName);
}
public LineageItem set(String varName, LineageItem li) {
return _traces.put(varName, li);
}
public LineageItem setLiteral(String varName, LineageItem li) {
return _literals.put(varName, li);
}
public LineageItem get(CPOperand variable) {
if (variable == null)
return null;
return _traces.get(variable.getName());
}
public boolean contains(CPOperand variable) {
return _traces.containsKey(variable.getName());
}
public boolean containsKey(String key) {
return _traces.containsKey(key);
}
public void resetLineageMaps() {
_traces.clear();
_literals.clear();
}
public Map<String, LineageItem> getTraces() {
return _traces;
}
public Map<String, LineageItem> getLiterals() {
return _literals;
}
private void trace(Instruction inst, ExecutionContext ec, Pair<String, LineageItem> li) {
if (inst instanceof VariableCPInstruction) {
VariableCPInstruction vcp_inst = ((VariableCPInstruction) inst);
switch (vcp_inst.getVariableOpcode()) {
case AssignVariable:
case CopyVariable: {
processCopyLI(li);
break;
}
case Read:
case CreateVariable: {
if (li != null)
addLineageItem(li);
break;
}
case RemoveVariable: {
for (CPOperand input : vcp_inst.getInputs())
removeLineageItem(input.getName());
break;
}
case Write: {
processWriteLI(vcp_inst.getInput1(), vcp_inst.getInput2(), ec);
break;
}
case MoveVariable: {
moveLineageItem(vcp_inst.getInput1().getName(), vcp_inst.getInput2().getName());
break;
}
case CastAsBooleanVariable:
case CastAsDoubleVariable:
case CastAsIntegerVariable:
case CastAsScalarVariable:
case CastAsMatrixVariable:
case CastAsFrameVariable: {
addLineageItem(li);
break;
}
default:
throw new DMLRuntimeException("Unknown VariableCPInstruction (" + inst.getOpcode() + ") traced.");
}
}
else if (inst instanceof WriteSPInstruction){
processWriteLI(((WriteSPInstruction) inst).getInput1(), ((WriteSPInstruction) inst).getInput2(), ec);
}
else
addLineageItem(li);
}
private Pair<String, LineageItem> cleanupInputLiterals(Pair<String, LineageItem> li, ExecutionContext ec) {
LineageItem item = li.getValue();
if( item.getInputs() == null )
return li;
// fix literals referring to variables (e.g., for/parfor loop variable)
for(int i=0; i<item.getInputs().length; i++) {
LineageItem tmp = item.getInputs()[i];
if( tmp.getType() != LineageItemType.Literal)
continue;
//check if CPOperand is not a literal, w/o parsing
if( tmp.getData().endsWith("false") ) {
CPOperand cp = new CPOperand(tmp.getData());
if( cp.getDataType().isScalar() ) {
cp.setLiteral(ec.getScalarInput(cp));
item.getInputs()[i] = getOrCreate(cp);
}
}
}
return li;
}
private void processCopyLI(Pair<String, LineageItem> li) {
if (li.getValue().getInputs().length != 1)
throw new DMLRuntimeException("AssignVariable and CopyVariable must have one input lineage item!");
//add item or overwrite existing item
_traces.put(li.getKey(), li.getValue().getInputs()[0]);
}
private void moveLineageItem(String keyFrom, String keyTo) {
LineageItem input = removeLineageItem(keyFrom);
if (!keyTo.equals("__pred"))
_traces.put(keyTo, input);
}
private LineageItem removeLineageItem(String key) {
//remove item if present
return _traces.remove(key);
}
private void addLineageItem(Pair<String, LineageItem> li) {
//add item or overwrite existing item
_traces.put(li.getKey(), li.getValue());
}
private void processWriteLI(CPOperand input1, CPOperand input2, ExecutionContext ec) {
LineageItem li = get(input1);
String fName = ec.getScalarInput(input2.getName(), Types.ValueType.STRING, input2.isLiteral()).getStringValue();
if (DMLScript.LINEAGE_DEDUP) {
// gracefully serialize the dedup maps without decompressing
LineageItemUtils.writeTraceToHDFS(LineageDedupUtils.mergeExplainDedupBlocks(ec), fName + ".lineage.dedup");
}
LineageItemUtils.writeTraceToHDFS(Explain.explain(li), fName + ".lineage");
}
}