blob: 55a6f151a028b17036b0817a5729b4fa19be91af [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IndexedIdentifier;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
/**
 * Rule: Insert checkpointing operations for caching purposes. Currently, we
 * follow a heuristic of checkpointing (1) all matrix variables used read-only
 * in loops, and (2) intermediates used by multiple consumers.
 *
 * TODO: heuristic (2) is not implemented yet, i.e., the injection of checkpoints
 * for intermediates with multiple consumers (local and global).
 */
public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRule
{
	/** If true, loops for which the rewrite status reports a parfor context are not rewritten. */
	private final boolean _checkCtx;
	
	/**
	 * Creates the rewrite rule.
	 * 
	 * @param checkParForContext if true, loops inside a parfor context are left
	 *        unchanged, so the decision for parfor loop bodies is deferred until
	 *        parfor optimization
	 */
	public RewriteInjectSparkLoopCheckpointing(boolean checkParForContext)
	{
		_checkCtx = checkParForContext;
	}
	
	/**
	 * In Spark execution mode, prepends a statement block with checkpointing
	 * operations to while/for/parfor loops. The new block covers every matrix
	 * variable the loop reads but does not update.
	 * 
	 * @param sb the statement block to rewrite
	 * @param status the rewrite status; provides the block size and parfor context,
	 *        and is marked via {@code setInjectedCheckpoints()} if a block is injected
	 * @return the (optional) injected checkpoint block followed by the original
	 *         statement block, which is always the last element
	 * @throws HopsException if hop construction fails
	 */
	@Override
	public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status)
		throws HopsException
	{
		ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
		
		if( !OptimizerUtils.isSparkExecutionMode() ) {
			ret.add(sb); // nothing to do here
			return ret; //return original statement block
		}
		
		//1) We currently add checkpoint operations without information about the global program structure,
		//this assumes that redundant checkpointing is prevented at runtime level (instruction-level)
		//2) Also, we do not take size information into account right now. This means that all candidates
		//are checkpointed even if they are only used by CP operations.
		int blocksize = status.getBlocksize(); //block size set by reblock rewrite
		
		//apply rewrite for while, for, and parfor (the decision for parfor loop bodies is deferred until parfor
		//optimization because otherwise we would prevent remote parfor)
		boolean isLoop = sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock; //incl parfor
		if( isLoop && !(_checkCtx && status.isInParforContext()) )
		{
			//step 1: determine checkpointing candidates
			VariableSet read = sb.variablesRead();
			ArrayList<String> candidates = getCheckpointCandidates(read, sb.variablesUpdated());
			
			//step 2: insert statement block with checkpointing operations
			if( !candidates.isEmpty() ) {
				ret.add(createCheckpointBlock(sb, read, candidates, blocksize));
				//maintain rewrite status
				status.setInjectedCheckpoints();
			}
		}
		
		//add original statement block to end
		ret.add(sb);
		return ret;
	}
	
	/**
	 * Collects the names of all matrix variables that are read but not updated
	 * (i.e., used read-only).
	 * 
	 * @param read variables read by the statement block
	 * @param updated variables updated by the statement block
	 * @return names of the checkpointing candidates (possibly empty)
	 */
	private static ArrayList<String> getCheckpointCandidates(VariableSet read, VariableSet updated)
	{
		ArrayList<String> candidates = new ArrayList<String>();
		for( String rvar : read.getVariableNames() )
			if( !updated.containsVariable(rvar) && read.getVariable(rvar).getDataType()==DataType.MATRIX )
				candidates.add(rvar);
		return candidates;
	}
	
	/**
	 * Creates a statement block that, for every candidate, reads the variable via a
	 * transient read marked as requiring a checkpoint. Each read is directly written
	 * back via a transient write under the same name.
	 * 
	 * @param sb original statement block (source of program and position information)
	 * @param read variables read by the original statement block
	 * @param candidates names of the variables to checkpoint
	 * @param blocksize row/column block size of the created hops
	 * @return the new checkpointing statement block
	 * @throws HopsException if hop construction fails
	 */
	private static StatementBlock createCheckpointBlock(StatementBlock sb, VariableSet read,
		ArrayList<String> candidates, int blocksize)
		throws HopsException
	{
		StatementBlock sb0 = new StatementBlock();
		sb0.setDMLProg(sb.getDMLProg());
		sb0.setAllPositions(sb.getFilename(), sb.getBeginLine(), sb.getBeginColumn(), sb.getEndLine(), sb.getEndColumn());
		ArrayList<Hop> hops = new ArrayList<Hop>();
		VariableSet livein = new VariableSet();
		VariableSet liveout = new VariableSet();
		
		for( String var : candidates )
		{
			DataIdentifier dat = read.getVariable(var);
			
			//for indexed reads (e.g., X[i,]) the identifier carries the size of the
			//indexing result, but the checkpoint covers the entire variable; hence,
			//use the original dimensions and treat the nnz as unknown
			long dim1 = dat.getDim1();
			long dim2 = dat.getDim2();
			long nnz = dat.getNnz();
			if( dat instanceof IndexedIdentifier ) {
				IndexedIdentifier idat = (IndexedIdentifier) dat;
				dim1 = idat.getOrigDim1();
				dim2 = idat.getOrigDim2();
				nnz = -1;
			}
			
			DataOp tread = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, DataOpTypes.TRANSIENTREAD,
				dat.getFilename(), dim1, dim2, nnz, blocksize, blocksize);
			tread.setRequiresCheckpoint( true );
			DataOp twrite = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, tread, DataOpTypes.TRANSIENTWRITE, null);
			HopRewriteUtils.setOutputParameters(twrite, dim1, dim2, blocksize, blocksize, nnz);
			hops.add(twrite);
			livein.addVariable(var, dat);
			liveout.addVariable(var, dat);
		}
		
		sb0.set_hops(hops);
		sb0.setLiveIn(livein);
		sb0.setLiveOut(liveout);
		return sb0;
	}
}