/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * This rewrite identifies and removes unnecessary checkpoints, i.e.,
 * persisting of Spark RDDs into a given storage level. For example,
 * in chains such as pread-checkpoint-append-checkpoint, the first
 * checkpoint is not used and creates unnecessary memory pressure.
 * 
 */
public class IPAPassRemoveUnnecessaryCheckpoints extends IPAPass
{
	@Override
	public boolean isApplicable(FunctionCallGraph fgraph) {
		return InterProceduralAnalysis.REMOVE_UNNECESSARY_CHECKPOINTS 
			&& OptimizerUtils.isSparkExecutionMode();
	}
	
	@Override
	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
		//remove unnecessary checkpoint before update 
		removeCheckpointBeforeUpdate(prog);
		
		//move necessary checkpoint after update
		moveCheckpointAfterUpdate(prog);
		
		//remove unnecessary checkpoint read-{write|uagg}
		removeCheckpointReadWrite(prog);
	}
	
	private static void removeCheckpointBeforeUpdate(DMLProgram dmlp) {
		//approach: scan over top-level program (guaranteed to be unconditional),
		//collect checkpoints; determine if used before update; remove first checkpoint
		//on second checkpoint if update in between and not used before update
		
		HashMap<String, Hop> chkpointCand = new HashMap<>();
		
		for( StatementBlock sb : dmlp.getStatementBlocks() ) 
		{
			//prune candidates (used before updated)
			Set<String> cands = new HashSet<>(chkpointCand.keySet());
			for( String cand : cands )
				if( sb.variablesRead().containsVariable(cand) 
					&& !sb.variablesUpdated().containsVariable(cand) ) 
				{	
					//note: variableRead might include false positives due to meta 
					//data operations like nrow(X) or operations removed by rewrites 
					//double check hops on basic blocks; otherwise worst-case
					boolean skipRemove = false;
					if( sb.getHops() !=null ) {
						Hop.resetVisitStatus(sb.getHops());
						skipRemove = true;
						for( Hop root : sb.getHops() )
							skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
					}					
					if( !skipRemove )
						chkpointCand.remove(cand);
				}
			
			//prune candidates (updated in conditional control flow)
			Set<String> cands2 = new HashSet<>(chkpointCand.keySet());
			if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock 
				|| sb instanceof ForStatementBlock )
			{
				for( String cand : cands2 )
					if( sb.variablesUpdated().containsVariable(cand) ) {
						chkpointCand.remove(cand);
					}
			}
			//prune candidates (updated w/ multiple reads) 
			else
			{
				for( String cand : cands2 )
					if( sb.variablesUpdated().containsVariable(cand) && sb.getHops() != null) 
					{
						Hop.resetVisitStatus(sb.getHops());
						for( Hop root : sb.getHops() )
							if( root.getName().equals(cand) &&
								!HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
								chkpointCand.remove(cand);
							}
					}	
			}
		
			//collect checkpoints and remove unnecessary checkpoints
			if( HopRewriteUtils.isLastLevelStatementBlock(sb) ) {
				ArrayList<Hop> tmp = collectCheckpoints(sb.getHops());
				for( Hop chkpoint : tmp ) {
					if( chkpointCand.containsKey(chkpoint.getName()) ) {
						chkpointCand.get(chkpoint.getName()).setRequiresCheckpoint(false);
					}
					chkpointCand.put(chkpoint.getName(), chkpoint);
				}
			}
		}
	}

	private static void moveCheckpointAfterUpdate(DMLProgram dmlp) {
		//approach: scan over top-level program (guaranteed to be unconditional),
		//collect checkpoints; determine if used before update; move first checkpoint
		//after update if not used before update (best effort move which often avoids
		//the second checkpoint on loops even though used in between)
		
		HashMap<String, Hop> chkpointCand = new HashMap<>();
		
		for( StatementBlock sb : dmlp.getStatementBlocks() ) 
		{
			//prune candidates (used before updated)
			Set<String> cands = new HashSet<>(chkpointCand.keySet());
			for( String cand : cands )
				if( sb.variablesRead().containsVariable(cand) 
					&& !sb.variablesUpdated().containsVariable(cand) ) 
				{	
					//note: variableRead might include false positives due to meta 
					//data operations like nrow(X) or operations removed by rewrites 
					//double check hops on basic blocks; otherwise worst-case
					boolean skipRemove = false;
					if( sb.getHops() !=null ) {
						Hop.resetVisitStatus(sb.getHops());
						skipRemove = true;
						for( Hop root : sb.getHops() )
							skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
					}					
					if( !skipRemove )
						chkpointCand.remove(cand);
				}
			
			//prune candidates (updated in conditional control flow)
			Set<String> cands2 = new HashSet<>(chkpointCand.keySet());
			if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock 
				|| sb instanceof ForStatementBlock )
			{
				for( String cand : cands2 )
					if( sb.variablesUpdated().containsVariable(cand) ) {
						chkpointCand.remove(cand);
					}
			}
			//move checkpoint after update with simple read chain 
			//(note: right now this only applies if the checkpoints comes from a previous
			//statement block, within-dag checkpoints should be handled during injection)
			else
			{
				for( String cand : cands2 )
					if( sb.variablesUpdated().containsVariable(cand) && sb.getHops() != null) {
						Hop.resetVisitStatus(sb.getHops());
						for( Hop root : sb.getHops() )
							if( root.getName().equals(cand) ) {
								if( HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
									chkpointCand.get(cand).setRequiresCheckpoint(false);
									root.getInput().get(0).setRequiresCheckpoint(true);
									chkpointCand.put(cand, root.getInput().get(0));
								}
								else
									chkpointCand.remove(cand);
							}
					}	
			}
		
			//collect checkpoints
			if( HopRewriteUtils.isLastLevelStatementBlock(sb) ) {
				ArrayList<Hop> tmp = collectCheckpoints(sb.getHops());
				for( Hop chkpoint : tmp )
					chkpointCand.put(chkpoint.getName(), chkpoint);
			}
		}
	}
	
	private static void removeCheckpointReadWrite(DMLProgram dmlp) {
		List<StatementBlock> sbs = dmlp.getStatementBlocks();

		if (sbs.size() == 1 && !(sbs.get(0) instanceof IfStatementBlock
				|| sbs.get(0) instanceof WhileStatementBlock
				|| sbs.get(0) instanceof ForStatementBlock)) {
			//recursively process all dag roots
			if (sbs.get(0).getHops() != null) {
				Hop.resetVisitStatus(sbs.get(0).getHops());
				for (Hop root : sbs.get(0).getHops())
					rRemoveCheckpointReadWrite(root);
			}
		}
	}
	
	private static ArrayList<Hop> collectCheckpoints(ArrayList<Hop> roots)
	{
		ArrayList<Hop> ret = new ArrayList<>();
		if( roots != null ) {
			Hop.resetVisitStatus(roots);
			for( Hop root : roots )
				rCollectCheckpoints(root, ret);
		}
		
		return ret;
	}
	
	private static void rCollectCheckpoints(Hop hop, ArrayList<Hop> checkpoints)
	{
		if( hop.isVisited() )
			return;

		//handle leaf node for variable (checkpoint directly bound
		//to logical variable name and not used)
		if( hop.requiresCheckpoint() && hop.getParent().size()==1 
			&& hop.getParent().get(0) instanceof DataOp
			&& ((DataOp)hop.getParent().get(0)).getOp()==OpOpData.TRANSIENTWRITE)
		{
			checkpoints.add(hop);
		}
		
		//recursively process child nodes
		for( Hop c : hop.getInput() )
			rCollectCheckpoints(c, checkpoints);
	
		hop.setVisited();
	}
	
	public static void rRemoveCheckpointReadWrite(Hop hop)
	{
		if( hop.isVisited() )
			return;

		//remove checkpoint on pread if only consumed by pwrite or uagg
		if( (hop instanceof DataOp && ((DataOp)hop).getOp()==OpOpData.PERSISTENTWRITE)
			|| hop instanceof AggUnaryOp )	
		{
			//(pwrite|uagg) - pread
			Hop c0 = hop.getInput().get(0);
			if( c0.requiresCheckpoint() && c0.getParent().size() == 1
				&& c0 instanceof DataOp && ((DataOp)c0).getOp()==OpOpData.PERSISTENTREAD )
			{
				c0.setRequiresCheckpoint(false);
			}
			
			//(pwrite|uagg) - frame/matri cast - pread
			if( c0 instanceof UnaryOp && c0.getParent().size() == 1 
				&& (((UnaryOp)c0).getOp()==OpOp1.CAST_AS_FRAME || ((UnaryOp)c0).getOp()==OpOp1.CAST_AS_MATRIX ) 
				&& c0.getInput().get(0).requiresCheckpoint() && c0.getInput().get(0).getParent().size() == 1
				&& c0.getInput().get(0) instanceof DataOp 
				&& ((DataOp)c0.getInput().get(0)).getOp()==OpOpData.PERSISTENTREAD )
			{
				c0.getInput().get(0).setRequiresCheckpoint(false);
			}
		}
		
		//recursively process children
		for( Hop c : hop.getInput() )
			rRemoveCheckpointReadWrite(c);
		
		hop.setVisited();
	}
}
