/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IndexedIdentifier;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOpData;
/**
* Rule: Insert checkpointing operations for caching purposes. Currently, we
* follow a heuristic of checkpointing (1) all variables used read-only in loops,
* and (2) intermediates used by multiple consumers; (2) is not yet implemented (see TODO below).
*
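* Illustrative example (a minimal DML sketch with hypothetical variable names):
* <pre>
*   while( iter &lt; maxiter ) {
*     grad = t(X) %*% (X %*% w - y);  # X, y: read-only matrices -> candidates
*     w = w - step * grad;            # w is updated -> not a candidate
*     iter = iter + 1;
*   }
* </pre>
* X and y are read but never updated inside the loop, so this rewrite injects a
* preceding statement block with checkpoint-marked transient reads (and transient
* writes back to the same names) of X and y, which lets the Spark backend cache
* them before the loop executes.
*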
* TODO (2) implement injection for multiple consumers (local and global).
*
*/
public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRule
{
private boolean _checkCtx = false;
public RewriteInjectSparkLoopCheckpointing(boolean checkParForContext) {
_checkCtx = checkParForContext;
}
@Override
public boolean createsSplitDag() {
return true;
}
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status)
{
if( !OptimizerUtils.isSparkExecutionMode() ) {
// nothing to do here, return original statement block
return Arrays.asList(sb);
}
//1) We currently add checkpoint operations without information about the global program structure;
//this assumes that redundant checkpointing is prevented at the runtime (instruction) level.
//2) Also, we do not yet take size information into account, which means that all candidates
//are checkpointed even if they are only ever used by CP operations.
ArrayList<StatementBlock> ret = new ArrayList<>();
int blocksize = status.getBlocksize(); //block size set by reblock rewrite
//apply rewrite for while, for, and parfor (the decision for parfor loop bodies is deferred until parfor
//optimization because otherwise we would prevent remote parfor)
if( (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) //incl parfor
&& (!_checkCtx || !status.isInParforContext()) )
{
//step 1: determine checkpointing candidates
ArrayList<String> candidates = new ArrayList<>();
VariableSet read = sb.variablesRead();
VariableSet updated = sb.variablesUpdated();
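//a variable is a candidate if it is read but never updated in the loop body
//and is of matrix or tensor type, i.e., a read-only loop input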
for( String rvar : read.getVariableNames() )
if( !updated.containsVariable(rvar) && (read.getVariable(rvar).getDataType()==DataType.MATRIX ||
read.getVariable(rvar).getDataType()==DataType.TENSOR))
candidates.add(rvar);
//step 2: insert statement block with checkpointing operations
if( !candidates.isEmpty() ) //existing candidates
{
StatementBlock sb0 = new StatementBlock();
sb0.setDMLProg(sb.getDMLProg());
sb0.setParseInfo(sb);
ArrayList<Hop> hops = new ArrayList<>();
VariableSet livein = new VariableSet();
VariableSet liveout = new VariableSet();
for( String var : candidates )
{
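//recreate a transient read for the loop input (using the original dimensions
//for indexed identifiers), mark it for checkpointing, and write it back
//under the same variable name so the loop body reads the cached result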
DataIdentifier dat = read.getVariable(var);
long dim1 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim1() : dat.getDim1();
long dim2 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim2() : dat.getDim2();
DataOp tread = new DataOp(var, dat.getDataType(), dat.getValueType(), OpOpData.TRANSIENTREAD,
dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize);
tread.setRequiresCheckpoint(true);
DataOp twrite = HopRewriteUtils.createTransientWrite(var, tread);
hops.add(twrite);
livein.addVariable(var, read.getVariable(var));
liveout.addVariable(var, read.getVariable(var));
}
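//finalize the injected statement block: attach the checkpoint hops and the
//live-in/live-out variable sets, and mark it as the result of a DAG split
//(see createsSplitDag above)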
sb0.setHops(hops);
sb0.setLiveIn(livein);
sb0.setLiveOut(liveout);
sb0.setSplitDag(true);
ret.add(sb0);
//maintain rewrite status
status.setInjectedCheckpoints();
}
}
//add original statement block to end
ret.add(sb);
return ret;
}
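//the list-level rewrite does not apply here; statement blocks are returned unchanged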
@Override
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus status) {
return sbs;
}
}