| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.rewrite; |
| |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.Map.Entry; |
| |
| import org.apache.sysds.common.Types.OpOpData; |
| import org.apache.sysds.hops.DataOp; |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.hops.LiteralOp; |
| import org.apache.sysds.parser.DataExpression; |
| |
| /** |
| * Rule: RemoveReadAfterWrite. If there is a persistent read with the same filename |
| * as a persistent write, and read has a higher line number than the write, |
| * we remove the read and consume the write input directly. This is important for two |
| * reasons (1) correctness and (2) performance. Without this rewrite, we could not |
| * guarantee the order of read-after-write because there is not data dependency |
| * |
| */ |
| public class RewriteRemoveReadAfterWrite extends HopRewriteRule |
| { |
| |
| @Override |
| @SuppressWarnings("unchecked") |
| public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) |
| { |
| if( roots == null ) |
| return null; |
| |
| //collect all persistent reads and writes |
| HashMap<String,Hop> reads = new HashMap<>(); |
| HashMap<String,Hop> writes = new HashMap<>(); |
| for( Hop h : roots ) |
| collectPersistentReadWriteOps( h, writes, reads ); |
| |
| Hop.resetVisitStatus(roots); |
| |
| //check persistent reads for read-after-write pattern |
| for( Entry<String, Hop> e : reads.entrySet() ) |
| { |
| String rfname = e.getKey(); |
| Hop rhop = e.getValue(); |
| if( writes.containsKey(rfname) //same persistent filename |
| && (writes.get(rfname).getBeginLine()<rhop.getBeginLine() //read after write |
| || writes.get(rfname).getEndLine()<rhop.getEndLine()) ) //note: account for bug in line handling, TODO remove after line handling resolved |
| { |
| //rewire read consumers to write input |
| Hop input = writes.get(rfname).getInput().get(0); |
| ArrayList<Hop> parents = (ArrayList<Hop>) rhop.getParent().clone(); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, rhop, input); |
| } |
| } |
| |
| return roots; |
| } |
| |
| @Override |
| public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { |
| //do noting, read/write do not occur in predicates |
| return root; |
| } |
| |
| private void collectPersistentReadWriteOps(Hop hop, HashMap<String,Hop> pWrites, HashMap<String,Hop> pReads) { |
| if( hop.isVisited() ) |
| return; |
| |
| //process childs |
| if( !hop.getInput().isEmpty() ) |
| for( Hop c : hop.getInput() ) |
| collectPersistentReadWriteOps(c, pWrites, pReads); |
| |
| //process current hop |
| if( hop instanceof DataOp ) |
| { |
| DataOp dop = (DataOp)hop; |
| if( dop.getOp()==OpOpData.PERSISTENTREAD ) |
| pReads.put(dop.getFileName(), dop); |
| else if( dop.getOp()==OpOpData.PERSISTENTWRITE ) |
| { |
| Hop fname = dop.getInput().get(dop.getParameterIndex(DataExpression.IO_FILENAME)); |
| if( fname instanceof LiteralOp ) //only constant writes |
| pWrites.put(((LiteralOp) fname).getStringValue(), dop); |
| } |
| } |
| |
| hop.setVisited(); |
| } |
| } |