blob: c18df0181f6cae60a7c66d3cd1de175d5b9f5593 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map.Entry;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.parser.DataExpression;
/**
 * Rule: RemoveReadAfterWrite. If there is a persistent read with the same filename
 * as a persistent write, and the read has a higher line number than the write,
 * we rewire all consumers of the read to the write input directly. This is important
 * for two reasons: (1) correctness and (2) performance. Without this rewrite, we could
 * not guarantee the order of read-after-write because there is no data dependency
 * between the write and the read.
 *
 * Only persistent writes whose filename is a literal (constant) are considered.
 * All persistent reads of a matching filename are rewired, not just one of them.
 */
public class RewriteRemoveReadAfterWrite extends HopRewriteRule
{
	/**
	 * Rewires the consumers of every persistent read that follows a persistent write
	 * to the same constant filename, so that they consume the write's input instead.
	 *
	 * @param roots root hops of the DAG to rewrite; may be null
	 * @param state program rewrite status (not used by this rule)
	 * @return the given roots (the DAG is modified in place), or null if roots is null
	 */
	@Override
	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
	{
		if( roots == null )
			return null;

		//collect all persistent reads and writes; reads are grouped per filename
		//because the same file may be read multiple times within one DAG
		HashMap<String,ArrayList<Hop>> reads = new HashMap<>();
		HashMap<String,Hop> writes = new HashMap<>();
		for( Hop h : roots )
			collectPersistentReadWriteOps( h, writes, reads );
		Hop.resetVisitStatus(roots);

		//check persistent reads for read-after-write pattern
		for( Entry<String, ArrayList<Hop>> e : reads.entrySet() )
		{
			Hop whop = writes.get(e.getKey()); //same persistent filename
			if( whop == null )
				continue;
			for( Hop rhop : e.getValue() )
				if( isReadAfterWrite(whop, rhop) )
					rewireReadConsumers(rhop, whop.getInput().get(0));
		}
		return roots;
	}

	/**
	 * No-op for predicate DAGs.
	 *
	 * @param root predicate root hop
	 * @param state program rewrite status (unused)
	 * @return the unmodified root
	 */
	@Override
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
		//do nothing, persistent reads/writes do not occur in predicates
		return root;
	}

	/**
	 * Decides, based on line numbers, whether the read occurs after the write.
	 */
	private static boolean isReadAfterWrite(Hop whop, Hop rhop) {
		return whop.getBeginLine() < rhop.getBeginLine()
			//note: account for bug in line handling, TODO remove after line handling resolved
			|| whop.getEndLine() < rhop.getEndLine();
	}

	/**
	 * Replaces every reference to the read hop in its consumers by the given input hop.
	 */
	private static void rewireReadConsumers(Hop rhop, Hop input) {
		//iterate over a copy, because rewiring may modify rhop's parent list
		ArrayList<Hop> parents = new ArrayList<>(rhop.getParent());
		for( Hop p : parents )
			HopRewriteUtils.replaceChildReference(p, rhop, input);
	}

	/**
	 * Recursively collects persistent reads (all of them, grouped by filename) and
	 * persistent writes with a literal filename (last visited per filename wins).
	 * Uses the hop visit status to process shared sub-DAGs only once; the caller
	 * is responsible for resetting the visit status afterwards.
	 *
	 * @param hop current hop
	 * @param pWrites output map from constant filename to persistent write
	 * @param pReads output map from filename to all persistent reads of that file
	 */
	private void collectPersistentReadWriteOps(Hop hop, HashMap<String,Hop> pWrites,
		HashMap<String,ArrayList<Hop>> pReads)
	{
		if( hop.isVisited() )
			return;

		//process children
		for( Hop c : hop.getInput() )
			collectPersistentReadWriteOps(c, pWrites, pReads);

		//process current hop
		if( hop instanceof DataOp )
		{
			DataOp dop = (DataOp)hop;
			if( dop.getOp()==OpOpData.PERSISTENTREAD )
				pReads.computeIfAbsent(dop.getFileName(), k -> new ArrayList<>()).add(dop);
			else if( dop.getOp()==OpOpData.PERSISTENTWRITE )
			{
				Hop fname = dop.getInput().get(dop.getParameterIndex(DataExpression.IO_FILENAME));
				if( fname instanceof LiteralOp ) //only constant writes
					pWrites.put(((LiteralOp) fname).getStringValue(), dop);
			}
		}

		hop.setVisited();
	}
}