/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.hops;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.DnnTransform;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

import java.util.ArrayList;

public class DnnOp extends MultiThreadedHop
{
	// -------------------------------------------------------------------------
	// This flag allows us to compile plans with less unknowns and also serves as future tensorblock integration.
	// By default, these flags are turned on.
	
	// When this flag is turned on, we attempt to check the parent convolution hop for unknown dimensions.
	// For example: in case of conv -> maxpool, the input channel/height/width of maxpool will match output channel/height/width of conv.
	private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = true;
	// This guards us from cases where the user provides incorrect C,H,W parameters.
	private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;
	// -------------------------------------------------------------------------
	
	// Specifies the type of this hop
	private OpOpDnn op;

	private DnnOp() {
		//default constructor for clone
	}

	/**
	 * Create a hop from the builtin expression
	 * 
	 * @param l name of the hop
	 * @param dt datatype (only supports matrix datatype)
	 * @param vt valuetype  (only supports matrix valuetype) 
	 * @param o type of this hop
	 * @param inp input hops
	 */
	public DnnOp(String l, DataType dt, ValueType vt, OpOpDnn o, ArrayList<Hop> inp) 
	{
		super(l, dt, vt);
		op = o;
		
		for( int i=0; i<inp.size(); i++ ) {
			Hop in = inp.get(i);
			getInput().add(i, in);
			in.getParent().add(this);
		}
		
		//compute unknown dims and nnz
		refreshSizeInformation();
	}

	@Override
	public void checkArity() {
		HopsException.check(_input.size() >= 1, this, "should have at least one input but has %d inputs", _input.size());
	}

	public OpOpDnn getOp() {
		return op;
	}
	
	@Override
	public String getOpString() {
		return op.toString();
	}

	private static boolean isEligibleForSpark() {
		return false;
	}
	
	@Override
	public boolean isGPUEnabled() {
		if(!DMLScript.USE_ACCELERATOR)
			return false;
		return true;
	}
	
	@Override
	public boolean isMultiThreadedOpType() {
		return true;
	}
	
	@Override
	public Lop constructLops()
	{
		//return already created lops
		if( getLops() != null )
			return getLops();
		
		ExecType et = optFindExecType();
		
		ArrayList<Hop> inputs = getInput();
		switch( op )
		{
			case MAX_POOL:
			case MAX_POOL_BACKWARD:
			case AVG_POOL:
			case AVG_POOL_BACKWARD:
			case CONV2D:
			case CONV2D_BACKWARD_DATA:
			case CONV2D_BACKWARD_FILTER:
			case BIASADD:
			case BIASMULT: {
				if(et == ExecType.CP || et == ExecType.GPU) {
					setLops(constructDnnLops(et, inputs));
					break;
				}
				throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
			}
			case BATCH_NORM2D_TEST:
			case CHANNEL_SUMS:
			case UPDATE_NESTEROV_X: {
				if(et == ExecType.GPU) {
					setLops(constructDnnLops(et, inputs));
					break;
				}
				throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
			}
			default:
				throw new HopsException("Unsupported lops construction for operation type '"+op+"'.");
		}
		
		//add reblock/checkpoint lops if necessary
		constructAndSetLopsDataFlowProperties();
		
		return getLops();
	}
	
	public void setOp(OpOpDnn op) {
		this.op = op;
	}
	
	private int getNumExpectedInputs() {
		switch(op) {
			case MAX_POOL_BACKWARD:
			case AVG_POOL_BACKWARD:
			case CONV2D:
			case CONV2D_BACKWARD_FILTER:
			case CONV2D_BACKWARD_DATA:
				return 14;
			case BIASADD:
			case BIASMULT:
				return 2;
			case BATCH_NORM2D_TEST:
				return 6;
			case CHANNEL_SUMS:
				return 3;
			case UPDATE_NESTEROV_X:
				return 4;
			default:
				return 13;
		}
	}
	
	/**
	 * Returns parent matrix X or null
	 * @param input input hop
	 * @return either null or X if input is max(X,0) or max(0,X)
	 */
	private static Hop isInputReLU(Hop input) {
		if(HopRewriteUtils.isBinary(input, OpOp2.MAX)) {
			if(HopRewriteUtils.isLiteralOfValue(input.getInput().get(0), 0)) {
				return input.getInput().get(1);
			}
			else if(HopRewriteUtils.isLiteralOfValue(input.getInput().get(1), 0)) {
				return input.getInput().get(0);
			}
			else
				return null; 
		}
		else
			return null;
	}
	
	private static boolean isInputConv2d(Hop input) {
		return HopRewriteUtils.isDnn(input, OpOpDnn.CONV2D);
	}
	
	/**
	 * Compares the input parameters for max_pool/max_pool_backward operations
	 * 
	 * @return true if the following parameters match: stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize, imgSize], pool_size=[poolSize1, poolSize2]
	 */
	private static boolean isPoolingParametersEqualAndKnown(DnnParameters param1, DnnParameters param2) {
		return isEqualAndKnown(param1.stride_h, param2.stride_h) && isEqualAndKnown(param1.stride_w, param2.stride_w) && 
			isEqualAndKnown(param1.pad_h, param2.pad_h) && isEqualAndKnown(param1.pad_w, param2.pad_w) &&
			isEqualAndKnown(param1.R, param2.R) && isEqualAndKnown(param1.S, param2.S) &&
			isEqualAndKnown(param1.N, param2.N) && isEqualAndKnown(param1.C, param2.C) &&
			isEqualAndKnown(param1.H, param2.H) && isEqualAndKnown(param1.W, param2.W);
	}
	
	public boolean isStride1Pad0() {
		DnnParameters tmp = parseInput();
		return tmp.stride_h == 1 && tmp.stride_w == 1
			&& tmp.pad_h == 0 && tmp.pad_w == 0;
	}
	
	private static boolean isEqualAndKnown(int val1, int val2) {
		return val1 >= 0 && val2 >= 0 && val1 == val2;
	}
	
	/**
	 * Returns the output lop of max_pool/avg_pool operation with same parameters as this hop.
	 * If corresponding output lop is not found or if this is not a max_pool_backward operation, this function returns null
	 * 
	 * @return output lop of max_pool/avg_pool operation with same parameters as this hop
	 */
	private Lop getMaxPoolOutputLop() {
		if(op == OpOpDnn.MAX_POOL_BACKWARD || op == OpOpDnn.AVG_POOL_BACKWARD) {
			OpOpDnn opType = (op == OpOpDnn.MAX_POOL_BACKWARD) ? OpOpDnn.MAX_POOL : OpOpDnn.AVG_POOL;
			Hop inputImage = getInput().get(0);
			for(Hop tmpParent : inputImage.getParent()) {
				if(!(tmpParent instanceof DnnOp))
					continue;
				DnnOp parent = (DnnOp) tmpParent;
				if(parent.getOp() == opType && isPoolingParametersEqualAndKnown(parent._cachedParams, _cachedParams)) {
					return parent.constructLops();
				}
			}
		}
		return null;
	}
	
	public Lop constructDnnLops(ExecType et, ArrayList<Hop> inputs) {
		if(inputs.size() != getNumExpectedInputs()) 
			throw new HopsException("Incorrect number of inputs for " + op.name());
		
		//TODO move these custom rewrites to the general hop rewrites
		// ---------------------------------------------------------------
		// Deal with fused operators and contruct lhsInputLop/optionalRhsInputLop
		Lop lhsInputLop = null; Lop optionalRhsInputLop = null;
		ArrayList<Hop> inputsOfPotentiallyFusedOp = inputs;
		
		OpOpDnn lopOp = op;
		// RELU_MAX_POOLING and RELU_MAX_POOLING_BACKWARD is extremely useful for CP backend 
		// by reducing unnecessary sparse-to-dense-to-sparse conversion.
		// For other backends, this operators is not necessary as it reduces an additional relu operator.
		Hop parentReLU = isInputReLU(inputs.get(0));
		if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == OpOpDnn.MAX_POOL && parentReLU != null) {
			lhsInputLop = parentReLU.constructLops();
			lopOp = OpOpDnn.RELU_MAX_POOL;
		}
		else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == OpOpDnn.MAX_POOL_BACKWARD && parentReLU != null) {
			lhsInputLop = parentReLU.constructLops();
			lopOp = OpOpDnn.RELU_MAX_POOL_BACKWARD;
		}
		else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && op == OpOpDnn.BIASADD && isInputConv2d(inputs.get(0))) {
			lopOp = OpOpDnn.CONV2D_BIAS_ADD;
			
			// the first lop is image 
			lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
			// the second lop is bias
			optionalRhsInputLop = inputs.get(1).constructLops();
			
			// Use the inputs from conv2d rather than bias_add
			inputsOfPotentiallyFusedOp = inputs.get(0).getInput();
		}
		else {
			lhsInputLop = inputs.get(0).constructLops();
		}
		// ---------------------------------------------------------------
		
		// ---------------------------------------------------------------
		// Compute intermediate memory budget that can be passed to GPU operators 
		// for better CuDNN operator selection at runtime
		double intermediateMemEstimate = computeIntermediateMemEstimate(-1, -1, -1 );
		if(et == ExecType.GPU && getDim1() >= 0 && getDim2() >= 0) {
			// This enables us to compile more efficient matrix-matrix CuDNN operation instead of 
			// row-by-row invocation of multiple vector-matrix CuDNN operations.
			// This is possible as the operations on GPU are single-threaded
			double optimisticIntermediateMemEstimate = GPUContextPool.initialGPUMemBudget() - getOutputMemEstimate() - inputs.get(0).getOutputMemEstimate();
			if(optionalRhsInputLop != null) {
				optimisticIntermediateMemEstimate -= inputs.get(1).getOutputMemEstimate();
			}
			intermediateMemEstimate = Math.max(intermediateMemEstimate, optimisticIntermediateMemEstimate);
		}
		// ---------------------------------------------------------------
		
		// Construct the lop
		Lop optionalMaxPoolOutput = (et == ExecType.GPU) ? getMaxPoolOutputLop() : null;
		Lop[] l2inputs = new Lop[inputsOfPotentiallyFusedOp.size()-1];
		for( int i=1; i < inputsOfPotentiallyFusedOp.size(); i++ )
			l2inputs[i-1] = inputsOfPotentiallyFusedOp.get(i).constructLops();
		DnnTransform convolutionLop = new DnnTransform(
			lhsInputLop, lopOp, getDataType(), getValueType(), et,
			OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), intermediateMemEstimate);
		setOutputDimensions(convolutionLop);
		setLineNumbers(convolutionLop);
		
		// ---------------------------------------------------------------
		// Add input/output for parent lops of convolutionLop
		lhsInputLop.addOutput(convolutionLop);
		if(optionalRhsInputLop != null) {
			convolutionLop.addInput(optionalRhsInputLop);
			optionalRhsInputLop.addOutput(convolutionLop);
		}
		for( int i=0; i < l2inputs.length; i++ ) {
			convolutionLop.addInput(l2inputs[i]);
			l2inputs[i].addOutput(convolutionLop);
		}
		// Only valid for MAX_POOLING_BACKWARD on GPU
		if(optionalMaxPoolOutput != null) {
			convolutionLop.addInput(optionalMaxPoolOutput);
			optionalMaxPoolOutput.addOutput(convolutionLop);
		}
		convolutionLop.updateLopProperties();
		
		// TODO double check that optionalMaxPoolOutput adheres to proper
		// ID ordering of constructed lops (previously hidden by setLevel)
		
		// ---------------------------------------------------------------
		
		return convolutionLop;
	}

			
	@Override
	protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
	{		
		if(getOp() == OpOpDnn.BIASMULT) {
			// in non-gpu mode, the worst case size of bias multiply operation is same as that of input.
			if(DMLScript.USE_ACCELERATOR) 
				return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
			else
				return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, getInput().get(0).getSparsity());
		}
		else {
			double sparsity = 1.0;
			return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
		}
	}
	
	// ---------------------------------------------------------------
	// Utility methods to guard the computation of memory estimates in presense of unknowns
	private static class IntermediateDimensions {
		int dim1; int dim2; double sp;
		public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str, double sp) {
			dim1 = (int) h.getDim(dim1Str);
			dim2 = (int) h.getDim(dim2Str);
			this.sp = sp;
		}
		public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str) {
			dim1 = (int) h.getDim(dim1Str);
			dim2 = (int) h.getDim(dim2Str);
			sp = 1;
		}
		public IntermediateDimensions(DnnOp h, int dim1, String dim2Str) {
			this.dim1 = dim1;
			dim2 = (int) h.getDim(dim2Str);
			sp = 1;
		}
		
		/**
		 * Add two computed memory estimates
		 * 
		 * @param val1 memory estimate 1
		 * @param val2 memory estimate 2
		 * @return sum of memory estimates
		 */
		static double guardedAdd(double val1, double val2) {
			if(val1 < 0 || val2 < 0) return OptimizerUtils.DEFAULT_SIZE;
			double ret = val1 + val2;
			if(ret >= OptimizerUtils.DEFAULT_SIZE) return OptimizerUtils.DEFAULT_SIZE;
			else return ret;
		}
		
		/**
		 * Compute memory estimates for given intermediate matrices 
		 * 
		 * @param intermediates list of intermediates
		 * @param numWorkers number of workers
		 * @return memory estimate
		 */
		public static double addEstimateSizes(ArrayList<IntermediateDimensions> intermediates, int numWorkers) {
			double memBudget = 0; 
			for(int i = 0; i < intermediates.size(); i++) {
				memBudget = guardedAdd(memBudget, OptimizerUtils.estimateSizeExactSparsity(
						intermediates.get(i).dim1, intermediates.get(i).dim2, intermediates.get(i).sp)*numWorkers);
			}
			return memBudget;
		}
		
		/**
		 * Compute max of two computed memory estimates
		 * @param val1 memory estimate 1
		 * @param val2 memory estimate 2
		 * @return max of memory estimates
		 */
		public static double guardedMax(double val1, double val2) {
			if(val1 < 0 || val2 < 0) return OptimizerUtils.DEFAULT_SIZE;
			double ret = Math.max(val1, val2);
			if(ret >= OptimizerUtils.DEFAULT_SIZE) return OptimizerUtils.DEFAULT_SIZE;
			else return ret;
		}
	}
	
	/**
	 * Helper utility to compute intermediate memory estimate
	 * 
	 * @param gpuIntermediates intermediates for GPU
	 * @param cpIntermediates intermediates for CP
	 * @return memory estimates
	 */
	private double computeIntermediateMemEstimateHelper(
			ArrayList<IntermediateDimensions> gpuIntermediates,
			ArrayList<IntermediateDimensions> cpIntermediates) {
		// Since CP operators use row-level parallelism by default
		int numWorkers = (int) Math.min(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), Math.max(getDim("N"), 1));
		if(DMLScript.USE_ACCELERATOR) {
			// Account for potential sparse-to-dense conversion
			double gpuMemBudget = IntermediateDimensions.addEstimateSizes(gpuIntermediates, 1);
			double cpMemoryBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
			if(cpMemoryBudget > gpuMemBudget) {
				double oneThreadCPMemBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, 1);
				if(oneThreadCPMemBudget <= gpuMemBudget) {
					// Why limit CPU ? in-order to give more opportunity to compile GPU operators
					cpMemoryBudget = oneThreadCPMemBudget;
				}
			}
			// Finally, use the maximum of CP and GPU memory budget
			return IntermediateDimensions.guardedMax(cpMemoryBudget, gpuMemBudget);
		}
		else {
			// When -gpu flag is not provided, the memory estimates for CP are not affected.
			return IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
		}
	}
	
	@Override
	protected double computeIntermediateMemEstimate( long ignoreDim1, long ignoreDim2, long ignoreNnz )
	{	
		ArrayList<IntermediateDimensions> gpuIntermediates = new ArrayList<>();
		ArrayList<IntermediateDimensions> cpIntermediates = new ArrayList<>();
		if(getOp() == OpOpDnn.CONV2D) {
			// Assumption: To compile a GPU conv2d operator, following should fit on the GPU:
			// 1. output in dense format (i.e. computeOutputMemEstimate) 
			// 2. input in any format
			// 3. atleast one input row in dense format
			// 4. filter in dense format
			
			// Account for potential sparse-to-dense conversion of atleast 1 input row and filter
			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
			gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
			
			// im2col operation preserves the worst-case sparsity of the input.
			cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
		}
		else if(getOp() == OpOpDnn.CONV2D_BACKWARD_DATA) {
			// Assumption: To compile a GPU conv2d_backward_data operator, following should fit on the GPU:
			// 1. output in dense format (i.e. computeOutputMemEstimate) 
			// 2. dout in any format
			// 3. atleast one dout row in dense format
			// 4. filter in dense format
			
			// Account for potential sparse-to-dense conversion of atleast 1 input row and filter
			gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
			gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
			
			// There are 2 intermediates: rotate180 and input to col2im for conv2d_backward_data
			// rotate180 preserves the "exact" sparsity of the dout matrix
			cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
			// Note: worst-case sparsity for the input of col2im (of size NPQ x CRS where N is determined by degree of parallelism)
			cpIntermediates.add(new IntermediateDimensions(this, "PQ", "CRS"));
		}
		else if(getOp() == OpOpDnn.CONV2D_BACKWARD_FILTER) {
			// Assumption: To compile a GPU conv2d_backward_filter operator, following should fit on the GPU:
			// 1. output in dense format (i.e. computeOutputMemEstimate) 
			// 2. dout in any format
			// 3. atleast one dout and input row in dense format
			
			// Account for potential sparse-to-dense conversion of atleast 1 input + dout row
			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
			gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
			
			// There are 2 intermediates: im2col and rotate180 for conv2d_backward_filter
			// rotate180 preserves the "exact" sparsity of the dout matrix
			cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
			// im2col operation preserves the worst-case sparsity of the input.
			cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
		}
		else if(getOp() == OpOpDnn.MAX_POOL || getOp() == OpOpDnn.AVG_POOL) {
			// Account for potential sparse-to-dense conversion of at least 1 input row
			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
		}
		else if(getOp() == OpOpDnn.MAX_POOL_BACKWARD || getOp() == OpOpDnn.AVG_POOL_BACKWARD) {
			// Account for potential sparse-to-dense conversion of at least 1 input + dout row
			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
			gpuIntermediates.add(new IntermediateDimensions(this, 1, "CPQ"));
		}
		
		if(gpuIntermediates.size() > 0 || cpIntermediates.size() > 0)
			return computeIntermediateMemEstimateHelper(gpuIntermediates, cpIntermediates);
		else
			return 0;
	}
	
	
	@Override
	protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
	{
		// [numRows, numCols, NNZ] 
		DataCharacteristics ret = new MatrixCharacteristics();
		
		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
			op == OpOpDnn.UPDATE_NESTEROV_X) {
			// Same dimension as the first input
			DataCharacteristics[] mc = memo.getAllInputStats(getInput());
			ret = new MatrixCharacteristics(
				mc[0].rowsKnown() ? mc[0].getRows() : -1,
				mc[0].colsKnown() ? mc[0].getCols() : -1, -1, -1);
			return ret.dimsKnown() ? ret : null;
		}
		else if(op == OpOpDnn.CHANNEL_SUMS) {
			long numChannels = Hop.computeSizeInformation(getInput().get(1));
			return new MatrixCharacteristics(numChannels, 1, -1, -1);
		}
		
		//safe return (create entry only if at least dims known)
		refreshSizeInformation();
		ret = _dc;
		return ret.dimsKnown() ? ret : null;
	}
	

	@Override
	public boolean allowsAllExecTypes()
	{
		return true;
	}
	
	@Override
	protected ExecType optFindExecType() {
		
		checkAndSetForcedPlatform();
		
		if( _etypeForced != null ) {
			_etype = _etypeForced;
		}
		else {
			if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
				_etype = findExecTypeByMemEstimate();
			}
			else {
				_etype = ExecType.SPARK;
			}
			
			//check for valid CP dimensions and matrix size
			checkAndSetInvalidCPDimsAndSize();
		}
		
		// TODO: Fix this after adding remaining spark instructions
		_etype = !isEligibleForSpark() && _etype == ExecType.SPARK ?  ExecType.CP : _etype;
		
		//mark for recompile (forever)
		setRequiresRecompileIfNecessary();
		
		return _etype;
	}
	
	// Parameters recomputed in refreshSizeInformation and passed across many calls of getDim
	private DnnParameters _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
	
	@SuppressWarnings("null")
	// stride1, stride2, padding1, padding2  
	// input_shape1, input_shape2, input_shape3, input_shape4, 
	// filter_shape1, filter_shape2, filter_shape3, filter_shape4
	DnnParameters parseInput() {
		
		Hop imageHeightHop = null; Hop filterHeightHop = null;
		if(op == OpOpDnn.MAX_POOL_BACKWARD || op == OpOpDnn.AVG_POOL_BACKWARD 
				|| op == OpOpDnn.CONV2D 
				|| op == OpOpDnn.CONV2D_BACKWARD_FILTER
				|| op == OpOpDnn.CONV2D_BACKWARD_DATA) {
			_cachedParams.setIfUnknown(
					getInput().get(6),  // N
					getInput().get(7),  // C
					getInput().get(8),  // H
					getInput().get(9),  // W
					getInput().get(10), // K
					getInput().get(12), // R
					getInput().get(13), // S
					getInput().get(2),  // stride_h
					getInput().get(3),  // stride_w
					getInput().get(4),  // pad+h
					getInput().get(5), _maxNumThreads);
		}
		else {
			_cachedParams.setIfUnknown(
					getInput().get(5),
					getInput().get(6),
					getInput().get(7),
					getInput().get(8),
					getInput().get(9),
					getInput().get(11),
					getInput().get(12),
					getInput().get(1),
					getInput().get(2),
					getInput().get(3),
					getInput().get(4), _maxNumThreads);
		}
		
		if(INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP) {
			boolean isPool = (getOp() == OpOpDnn.MAX_POOL || getOp() == OpOpDnn.AVG_POOL);
			boolean isConv = getOp() == OpOpDnn.CONV2D;
			boolean unknownCHWPQ = _cachedParams.C < 0 || _cachedParams.H < 0 || _cachedParams.W < 0 || _cachedParams.P < 0 || _cachedParams.Q < 0;
			if((isPool || isConv) && unknownCHWPQ) {
				// Only infer input shape for convolution and maxpool
				inferCHWPQFromParentOp();
			}
		}
		
		if(imageHeightHop == filterHeightHop && _cachedParams.R < 0 && _cachedParams.H > 0) {
			// Unknown R, but known H and both are equal
			// This happens for one-dimensional conv2d where H=R and H can be inferred from the parent hop
			_cachedParams.R = _cachedParams.H;
		}
		
		// Compute P and Q if unknown. At script level, they are computed using following script:
		// P = as.integer(floor((H + 2*pad_h - R)/stride_h + 1))
		// Q = as.integer(floor((W + 2*pad_w - S)/stride_w + 1))
		if(_cachedParams.P < 0 && _cachedParams.H >= 0 && _cachedParams.R >= 0 && _cachedParams.stride_h >= 0 && _cachedParams.pad_h >= 0) {
			_cachedParams.P = (int) org.apache.sysds.runtime.util.DnnUtils.getP(_cachedParams.H, _cachedParams.R, _cachedParams.stride_h, _cachedParams.pad_h);
		}
		if(_cachedParams.Q < 0 && _cachedParams.W >= 0 && _cachedParams.S >= 0 && _cachedParams.stride_w >= 0 && _cachedParams.pad_w >= 0) {
			_cachedParams.Q = (int) org.apache.sysds.runtime.util.DnnUtils.getQ(_cachedParams.W, _cachedParams.S, _cachedParams.stride_w, _cachedParams.pad_w);
		}
		
		return _cachedParams;
	}
	
	/**
	 * Utility method to check if the given hop is a BIAS_ADD hop
	 * 
	 * @param hop the given hop
	 * @return true if the given hop is BIAS_ADD
	 */
	private static boolean isInputBiasAdd(Hop hop) {
		return HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD);
	}
	
	/**
	 * Utility method to check if the inferred shapes are equal to the given shape with a guard for unknown
	 * 
	 * @param dim1 inferred shape
	 * @param dim2 given shape
	 * @param paramType string denoting the parameter for pretty printing of the error message
	 */
	private static void throwExceptionIfNotEqual(int dim1, int dim2, String paramType) {
		if(dim1 >= 0 && dim2 >= 0 && dim1 != dim2) {
			throw new DMLRuntimeException("Inferred " + paramType + " from parent doesn't match with given " + paramType + ":" + dim1 + " != " + dim2);
		}
	}
	
	/**
	 * Gets the values for the parameters C, H, W, P, Q from parent hops
	 */
	private void inferCHWPQFromParentOp() {
		Hop tmp = getInput().get(0);
		// Skip bias_add and go to its parent
		tmp = isInputBiasAdd(tmp) ? tmp.getInput().get(0) : tmp;
		Hop parentReLU = isInputReLU(tmp);
		// Skip ReLU and go to its parent
		tmp =  (parentReLU != null) ? parentReLU : tmp;
		
		// Cast tmp as parent
		DnnOp parentOp = (tmp instanceof DnnOp) ? ((DnnOp) tmp) : null; 
		
		if(parentOp == null)
			return;
		else if(parentOp.getOp() == OpOpDnn.MAX_POOL || parentOp.getOp() == OpOpDnn.AVG_POOL) {
			DnnParameters parentParam = parentOp.parseInput();
			int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
			// [C, P, Q] from maxpool becomes [C, H, W] of next op
			_cachedParams.C = (_cachedParams.C < 0) ? parentParam.C : _cachedParams.C;
			_cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
			_cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
			if(LOG.isDebugEnabled()) {
				LOG.debug("Inferring [C,H,W] from maxpool parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
			}
			if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
				throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
				throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
				throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
			}
		}
		else if(parentOp.getOp() == OpOpDnn.CONV2D) {
			DnnParameters parentParam = parentOp.parseInput();
			int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
			// [K, P, Q] from convolution becomes [C, H, W] of next op
			_cachedParams.C = (_cachedParams.C < 0) ? parentParam.K : _cachedParams.C;
			_cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
			_cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
			if(LOG.isDebugEnabled()) {
				LOG.debug("Inferring [C,H,W] from maxpool parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
			}
			if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
				throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
				throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
				throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
			}
		}
	}
	
	@Override
	public void refreshSizeInformation()
	{
		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT 
			|| op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
			// Same dimension as the first input
			Hop input1 = getInput().get(0);
			setDim1(input1.getDim1());
			setDim2(input1.getDim2());
			setNnz(-1); // cannot infer stats
			return;
		}
		else if(op == OpOpDnn.CHANNEL_SUMS) {
			long numChannels = Hop.computeSizeInformation(getInput().get(1));
			setDim1(numChannels);
			setDim2(1);
			setNnz(-1); // cannot infer stats
			return;
		}
		
		// Reset the _cachedParams to avoid incorrect sizes
		_cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
		
		switch(op) 
		{
			case MAX_POOL:
			case AVG_POOL: {
				setDim1(getDim("N"));
				setDim2(getDim("CPQ"));
				setNnz(-1); // cannot infer stats
				break;
			}
			case MAX_POOL_BACKWARD:
			case AVG_POOL_BACKWARD: {
				setDim1(getDim("N"));
				setDim2(getDim("CHW"));
				setNnz(-1);
				break;
			}
			case CONV2D: {
				setDim1(getDim("N"));
				setDim2(getDim("KPQ"));
				setNnz(-1); // cannot infer stats
				break;
			}
			case CONV2D_BACKWARD_DATA: {
				setDim1(getDim("N"));
				setDim2(getDim("CHW"));
				setNnz(-1); // cannot infer stats
				break;
			}
			case CONV2D_BACKWARD_FILTER: {
				setDim1(getDim("K"));
				setDim2(getDim("CRS"));
				setNnz(-1); // cannot infer stats
				break;
			}
			default:
				throw new RuntimeException("The sizes are not refreshed for " + op.name());
		}
	}
	
	@Override
	public Object clone() throws CloneNotSupportedException {
		DnnOp ret = new DnnOp();
		
		//copy generic attributes
		ret.clone(this, false);
		
		//copy specific attributes
		ret.op = op;
		ret._maxNumThreads = _maxNumThreads;
		return ret;
	}
	
	@Override
	public boolean compare( Hop that )
	{
		if( !(that instanceof DnnOp) )
			return false;
		
		DnnOp that2 = (DnnOp)that;
		
		boolean ret =  (op == that2.op)
				    && (getInput().size()==that.getInput().size())
				    && _maxNumThreads == that2._maxNumThreads;
		
		//compare all childs
		if( ret ) //sizes matched
			for( int i=0; i<_input.size(); i++ )
				ret &= getInput().get(i) == that2.getInput().get(i);
		
		return ret;
	}
	
	// ------------------------------------------------------------------------------------------------------
	// Utility methods to get the dimensions taking into account unknown dimensions
	
	/**
	 * Convenient method to get the dimensions required by ConvolutionOp.
	 * 
	 * @param dimString can be K, CRS, N, CHW, KPQ, PQ
	 * @return either -1 or value associated with the dimString
	 */
	private long getDim(String dimString) {
		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT
			|| op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
			op == OpOpDnn.UPDATE_NESTEROV_X) {
			throw new RuntimeException("getDim method should not be invoked for " + op.name());
		}
		try {
			parseInput();
		} catch (DMLRuntimeException e) {
			throw new RuntimeException(e);
		}
		Hop filter = null; 	// shape: K x CRS 
		Hop input = null; 	// shape: N x CHW
		Hop dout = null;	// shape: N x KPQ
		Hop dout1 = null;	// shape: N x CPQ
		
		if(getOp() == OpOpDnn.CONV2D) {
			input  = getInput().get(0);
			filter = getInput().get(1);
		}
		else if(getOp() == OpOpDnn.CONV2D_BACKWARD_DATA) {
			filter = getInput().get(0);
			dout  = getInput().get(1);
		}
		else if(getOp() == OpOpDnn.CONV2D_BACKWARD_FILTER) {
			input = getInput().get(0);
			dout  = getInput().get(1);
		}
		else if(getOp() == OpOpDnn.MAX_POOL || getOp() == OpOpDnn.AVG_POOL) {
			input = getInput().get(0);
		}
		else if(getOp() == OpOpDnn.MAX_POOL_BACKWARD || getOp() == OpOpDnn.AVG_POOL_BACKWARD) {
			input = getInput().get(0);
			dout1  = getInput().get(1);
		}
		
		long ret = -1;
		if(dimString.equals("K") && filter != null) {
			ret = getNonNegative(ret, getNonNegative(_cachedParams.K, filter.getDim1()));
		}
		else if(dimString.equals("CRS") && filter != null) {
			ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S), filter.getDim2()));
		}
		else if(dimString.equals("N") && input != null) {
			ret = getNonNegative(ret, getNonNegative(_cachedParams.N, input.getDim1()));
		}
		else if(dimString.equals("CHW") && input != null) {
			ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W), input.getDim2()));
		}
		else if(dimString.equals("N") && dout != null) {
			ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout.getDim1()));
		}
		else if(dimString.equals("KPQ") && dout != null) {
			ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q), dout.getDim2()));
		}
		else if(dimString.equals("N") && dout1 != null) {
			ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout1.getDim1()));
		}
		else if(dimString.equals("CPQ") && dout1 != null) {
			ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q), dout1.getDim2()));
		}
		else if(dimString.equals("K")) {
			ret = getNonNegative(ret, _cachedParams.K >= 0 ? _cachedParams.K : -1);
		}
		else if(dimString.equals("CRS")) {
			ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S));
		}
		else if(dimString.equals("N")) {
			ret = getNonNegative(ret, _cachedParams.N >= 0 ? _cachedParams.N : -1);
		}
		else if(dimString.equals("CHW")) {
			ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W));
		}
		else if(dimString.equals("KPQ")) {
			ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q));
		}
		else if(dimString.equals("PQ")) {
			ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.P, _cachedParams.Q));
		}
		else if(dimString.equals("CPQ")) {
			ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q));
		}
		else {
			throw new RuntimeException("Unsupported dimension:" + dimString + " for operator " + getOp().name());
		}
		
		if(LOG.isDebugEnabled() && ret < 0) {
			LOG.debug("Unknown dimension " + dimString + " for DnnOp:" + op.name() + 
					" img_dim=[" + _cachedParams.N + " " + _cachedParams.C + " " + _cachedParams.H + " " + _cachedParams.W + "]" +
					" filter_dim=[" + _cachedParams.K + " " + _cachedParams.C + " " + _cachedParams.R + " " + _cachedParams.S + "]" + 
					" output_feature_map=[" + _cachedParams.P + " " + _cachedParams.Q + "] stride=[" + _cachedParams.stride_h + " " + _cachedParams.stride_w + "]" +
					" pad=[" + _cachedParams.pad_h + " " + _cachedParams.pad_w + "]");
		}
		return ret;
	}
	
	private static long nonNegativeMultiply(long val1, long val2, long val3) {
		if(val1 >= 0 && val2 >= 0 && val3 >= 0) {
			return val1 * val2 * val3;
		}
		else return -1;
	}
	private static long nonNegativeMultiply(long val1, long val2) {
		if(val1 >= 0 && val2 >= 0) {
			return val1 * val2;
		}
		else return -1;
	}
	private static long getNonNegative(long val1, long val2) {
		if(val1 >= 0 && val2 >= 0) {
			if(val1 == val2) return val1;
			else throw new RuntimeException("Incorrect dimensions in DnnOp: " + val1 + " != " + val2);
		}
		else if(val1 >= 0) return val1;
		else if(val2 >= 0) return val2;
		else return -1;
	}
	// ------------------------------------------------------------------------------------------------------
}
