| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.parser; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.sysml.conf.ConfigurationManager; |
| import org.apache.sysml.hops.AggBinaryOp; |
| import org.apache.sysml.hops.AggUnaryOp; |
| import org.apache.sysml.hops.BinaryOp; |
| import org.apache.sysml.hops.DataGenOp; |
| import org.apache.sysml.hops.DataOp; |
| import org.apache.sysml.hops.FunctionOp; |
| import org.apache.sysml.hops.FunctionOp.FunctionType; |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.Hop.AggOp; |
| import org.apache.sysml.hops.Hop.DataGenMethod; |
| import org.apache.sysml.hops.Hop.DataOpTypes; |
| import org.apache.sysml.hops.Hop.Direction; |
| import org.apache.sysml.hops.Hop.OpOp2; |
| import org.apache.sysml.hops.Hop.OpOp3; |
| import org.apache.sysml.hops.Hop.ParamBuiltinOp; |
| import org.apache.sysml.hops.Hop.ReOrgOp; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.hops.IndexingOp; |
| import org.apache.sysml.hops.LeftIndexingOp; |
| import org.apache.sysml.hops.LiteralOp; |
| import org.apache.sysml.hops.MemoTable; |
| import org.apache.sysml.hops.OptimizerUtils; |
| import org.apache.sysml.hops.ParameterizedBuiltinOp; |
| import org.apache.sysml.hops.ReorgOp; |
| import org.apache.sysml.hops.TernaryOp; |
| import org.apache.sysml.hops.UnaryOp; |
| import org.apache.sysml.hops.ipa.InterProceduralAnalysis; |
| import org.apache.sysml.hops.rewrite.ProgramRewriter; |
| import org.apache.sysml.hops.recompile.Recompiler; |
| import org.apache.sysml.lops.Lop; |
| import org.apache.sysml.lops.LopsException; |
| import org.apache.sysml.parser.Expression.DataType; |
| import org.apache.sysml.parser.Expression.FormatType; |
| import org.apache.sysml.parser.Expression.ParameterizedBuiltinFunctionOp; |
| import org.apache.sysml.parser.Expression.ValueType; |
| import org.apache.sysml.parser.PrintStatement.PRINTTYPE; |
| import org.apache.sysml.runtime.DMLRuntimeException; |
| import org.apache.sysml.hops.ConvolutionOp; |
| import org.apache.sysml.hops.rewrite.HopRewriteUtils; |
| import org.apache.sysml.parser.Expression.BuiltinFunctionOp; |
| |
| |
| public class DMLTranslator |
| { |
private static final Log LOG = LogFactory.getLog(DMLTranslator.class.getName());
private DMLProgram _dmlProg = null;

/**
 * Creates a translator bound to the given DML program.
 * <p>
 * Besides remembering the program, construction resets compiler-global state:
 * the default size used for unknown dimensions is reset, and the recompiler
 * (and with it the rewriter, according to the current optimization-level
 * flags) is re-initialized.
 *
 * @param dmlp the DML program to translate
 * @throws DMLRuntimeException as declared; presumably propagated from the
 *         optimizer/recompiler setup calls
 */
public DMLTranslator(DMLProgram dmlp)
	throws DMLRuntimeException
{
	this._dmlProg = dmlp;

	// global setup: default size for unknown dims, then a freshly initialized recompiler
	OptimizerUtils.resetDefaultSize();
	Recompiler.reinitRecompiler();
}
| |
| /** |
| * Validate parse tree |
| * |
| * @throws LanguageException |
| * @throws IOException |
| */ |
| public void validateParseTree(DMLProgram dmlp) |
| throws LanguageException, ParseException, IOException |
| { |
| //STEP1: Pre-processing steps for validate - e.g., prepare read-after-write meta data |
| boolean fWriteRead = prepareReadAfterWrite(dmlp, new HashMap<String, DataIdentifier>()); |
| |
| //STEP2: Actual Validate |
| // handle functions in namespaces (current program has default namespace) |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| |
| // for each function defined in the namespace |
| for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { |
| FunctionStatementBlock fblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); |
| |
| HashMap<String, ConstIdentifier> constVars = new HashMap<String, ConstIdentifier>(); |
| VariableSet vs = new VariableSet(); |
| |
| // add the input variables for the function to input variable list |
| FunctionStatement fstmt = (FunctionStatement)fblock.getStatement(0); |
| if (fblock.getNumStatements() > 1){ |
| LOG.error(fstmt.printErrorLocation() + "FunctionStatementBlock can only have 1 FunctionStatement"); |
| throw new LanguageException(fstmt.printErrorLocation() + "FunctionStatementBlock can only have 1 FunctionStatement"); |
| } |
| |
| for (DataIdentifier currVar : fstmt.getInputParams()) { |
| |
| if (currVar.getDataType() == DataType.SCALAR){ |
| currVar.setDimensions(0, 0); |
| } |
| |
| vs.addVariable(currVar.getName(), currVar); |
| } |
| fblock.validate(dmlp, vs, constVars, false); |
| } |
| |
| } |
| |
| // handle regular blocks -- "main" program |
| VariableSet vs = new VariableSet(); |
| HashMap<String, ConstIdentifier> constVars = new HashMap<String, ConstIdentifier>(); |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock sb = dmlp.getStatementBlock(i); |
| vs = sb.validate(dmlp, vs, constVars, fWriteRead); |
| constVars = sb.getConstOut(); |
| } |
| |
| //STEP3: Post-processing steps after validate - e.g., prepare read-after-write meta data |
| if( fWriteRead ) |
| { |
| //propagate size and datatypes into read |
| prepareReadAfterWrite(dmlp, new HashMap<String, DataIdentifier>()); |
| |
| //re-validate main program for datatype propagation |
| vs = new VariableSet(); |
| constVars = new HashMap<String, ConstIdentifier>(); |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock sb = dmlp.getStatementBlock(i); |
| vs = sb.validate(dmlp, vs, constVars, fWriteRead); |
| constVars = sb.getConstOut(); |
| } |
| } |
| |
| return; |
| } |
| |
| public void liveVariableAnalysis(DMLProgram dmlp) throws LanguageException { |
| |
| // for each namespace, handle function program blocks -- forward direction |
| for (String namespaceKey : dmlp.getNamespaces().keySet()) { |
| for (String fname: dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { |
| FunctionStatementBlock fsb = dmlp.getFunctionStatementBlock(namespaceKey, fname); |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| |
| // perform function inlining |
| fstmt.setBody(StatementBlock.mergeFunctionCalls(fstmt.getBody(), dmlp)); |
| |
| VariableSet activeIn = new VariableSet(); |
| for (DataIdentifier id : fstmt.getInputParams()){ |
| activeIn.addVariable(id.getName(), id); |
| } |
| fsb.initializeforwardLV(activeIn); |
| } |
| } |
| |
| // for each namespace, handle function program blocks -- backward direction |
| for (String namespaceKey : dmlp.getNamespaces().keySet()) { |
| for (String fname: dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { |
| |
| // add output variables to liveout / activeout set |
| FunctionStatementBlock fsb = dmlp.getFunctionStatementBlock(namespaceKey, fname); |
| VariableSet currentLiveOut = new VariableSet(); |
| VariableSet currentLiveIn = new VariableSet(); |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| |
| for (DataIdentifier id : fstmt.getInputParams()) |
| currentLiveIn.addVariable(id.getName(), id); |
| |
| for (DataIdentifier id : fstmt.getOutputParams()) |
| currentLiveOut.addVariable(id.getName(), id); |
| |
| fsb._liveOut = currentLiveOut; |
| fsb.analyze(currentLiveIn, currentLiveOut); |
| } |
| } |
| |
| |
| // handle regular program blocks |
| VariableSet currentLiveOut = new VariableSet(); |
| VariableSet activeIn = new VariableSet(); |
| |
| // handle function inlining |
| dmlp.setStatementBlocks(StatementBlock.mergeFunctionCalls(dmlp.getStatementBlocks(), dmlp)); |
| |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock sb = dmlp.getStatementBlock(i); |
| activeIn = sb.initializeforwardLV(activeIn); |
| } |
| |
| if (dmlp.getNumStatementBlocks() > 0){ |
| StatementBlock lastSb = dmlp.getStatementBlock(dmlp.getNumStatementBlocks() - 1); |
| lastSb._liveOut = new VariableSet(); |
| for (int i = dmlp.getNumStatementBlocks() - 1; i >= 0; i--) { |
| StatementBlock sb = dmlp.getStatementBlock(i); |
| currentLiveOut = sb.analyze(currentLiveOut); |
| } |
| } |
| return; |
| |
| } |
| |
| /** |
| * Construct Hops from parse tree |
| * |
| * @throws ParseException |
| */ |
| public void constructHops(DMLProgram dmlp) |
| throws ParseException, LanguageException |
| { |
| // Step 1: construct hops for all functions |
| // for each namespace, handle function program blocks |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| for (String fname: dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { |
| FunctionStatementBlock current = dmlp.getFunctionStatementBlock(namespaceKey, fname); |
| constructHops(current); |
| } |
| } |
| |
| // Step 2: construct hops for main program |
| // handle regular program blocks |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock current = dmlp.getStatementBlock(i); |
| constructHops(current); |
| } |
| } |
| |
| /** |
| * |
| * @param dmlp |
| * @throws ParseException |
| * @throws LanguageException |
| * @throws HopsException |
| */ |
| public void rewriteHopsDAG(DMLProgram dmlp) |
| throws ParseException, LanguageException, HopsException |
| { |
| //apply hop rewrites (static rewrites) |
| ProgramRewriter rewriter = new ProgramRewriter(true, false); |
| rewriter.rewriteProgramHopDAGs(dmlp); |
| resetHopsDAGVisitStatus(dmlp); |
| |
| //propagate size information from main into functions (but conservatively) |
| if( OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS ) { |
| InterProceduralAnalysis ipa = new InterProceduralAnalysis(); |
| ipa.analyzeProgram(dmlp); |
| resetHopsDAGVisitStatus(dmlp); |
| } |
| |
| //apply hop rewrites (dynamic rewrites, after IPA) |
| ProgramRewriter rewriter2 = new ProgramRewriter(false, true); |
| rewriter2.rewriteProgramHopDAGs(dmlp); |
| resetHopsDAGVisitStatus(dmlp); |
| |
| // Compute memory estimates for all the hops. These estimates are used |
| // subsequently in various optimizations, e.g. CP vs. MR scheduling and parfor. |
| refreshMemEstimates(dmlp); |
| resetHopsDAGVisitStatus(dmlp); |
| } |
| |
| |
| public void constructLops(DMLProgram dmlp) throws ParseException, LanguageException, HopsException, LopsException { |
| |
| // for each namespace, handle function program blocks handle function |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| for (String fname: dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { |
| FunctionStatementBlock current = dmlp.getFunctionStatementBlock(namespaceKey, fname); |
| constructLops(current); |
| } |
| } |
| |
| // handle regular program blocks |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock current = dmlp.getStatementBlock(i); |
| constructLops(current); |
| } |
| } |
| |
| /** |
| * |
| * @param sb |
| * @throws HopsException |
| * @throws LopsException |
| */ |
| public void constructLops(StatementBlock sb) |
| throws HopsException, LopsException |
| { |
| if (sb instanceof WhileStatementBlock) |
| { |
| WhileStatementBlock wsb = (WhileStatementBlock)sb; |
| WhileStatement whileStmt = (WhileStatement)wsb.getStatement(0); |
| ArrayList<StatementBlock> body = whileStmt.getBody(); |
| |
| if (sb.get_hops() != null && !sb.get_hops().isEmpty()) { |
| LOG.error(sb.printBlockErrorLocation() + "WhileStatementBlock should not have hops"); |
| throw new HopsException(sb.printBlockErrorLocation() + "WhileStatementBlock should not have hops"); |
| } |
| // step through stmt blocks in while stmt body |
| for (StatementBlock stmtBlock : body){ |
| constructLops(stmtBlock); |
| } |
| |
| // handle while stmt predicate |
| Lop l = wsb.getPredicateHops().constructLops(); |
| wsb.set_predicateLops(l); |
| wsb.updatePredicateRecompilationFlag(); |
| } |
| |
| else if (sb instanceof IfStatementBlock) |
| { |
| IfStatementBlock isb = (IfStatementBlock) sb; |
| IfStatement ifStmt = (IfStatement)isb.getStatement(0); |
| ArrayList<StatementBlock> ifBody = ifStmt.getIfBody(); |
| ArrayList<StatementBlock> elseBody = ifStmt.getElseBody(); |
| |
| if (sb.get_hops() != null && !sb.get_hops().isEmpty()){ |
| LOG.error(sb.printBlockErrorLocation() + "IfStatementBlock should not have hops"); |
| throw new HopsException(sb.printBlockErrorLocation() + "IfStatementBlock should not have hops"); |
| } |
| // step through stmt blocks in if stmt ifBody |
| for (StatementBlock stmtBlock : ifBody) |
| constructLops(stmtBlock); |
| |
| // step through stmt blocks in if stmt elseBody |
| for (StatementBlock stmtBlock : elseBody) |
| constructLops(stmtBlock); |
| |
| // handle if stmt predicate |
| Lop l = isb.getPredicateHops().constructLops(); |
| isb.set_predicateLops(l); |
| isb.updatePredicateRecompilationFlag(); |
| } |
| |
| else if (sb instanceof ForStatementBlock) //NOTE: applies to ForStatementBlock and ParForStatementBlock |
| { |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fs = (ForStatement)sb.getStatement(0); |
| ArrayList<StatementBlock> body = fs.getBody(); |
| |
| if (sb.get_hops() != null && !sb.get_hops().isEmpty() ) { |
| LOG.error(sb.printBlockErrorLocation() + "ForStatementBlock should not have hops"); |
| throw new HopsException(sb.printBlockErrorLocation() + "ForStatementBlock should not have hops"); |
| } |
| // step through stmt blocks in FOR stmt body |
| for (StatementBlock stmtBlock : body) |
| constructLops(stmtBlock); |
| |
| // handle for stmt predicate |
| if (fsb.getFromHops() != null){ |
| Lop llobs = fsb.getFromHops().constructLops(); |
| fsb.setFromLops(llobs); |
| } |
| if (fsb.getToHops() != null){ |
| Lop llobs = fsb.getToHops().constructLops(); |
| fsb.setToLops(llobs); |
| } |
| if (fsb.getIncrementHops() != null){ |
| Lop llobs = fsb.getIncrementHops().constructLops(); |
| fsb.setIncrementLops(llobs); |
| } |
| fsb.updatePredicateRecompilationFlags(); |
| } |
| else if (sb instanceof FunctionStatementBlock){ |
| FunctionStatement functStmt = (FunctionStatement)sb.getStatement(0); |
| ArrayList<StatementBlock> body = functStmt.getBody(); |
| |
| if (sb.get_hops() != null && !sb.get_hops().isEmpty()) { |
| LOG.error(sb.printBlockErrorLocation() + "FunctionStatementBlock should not have hops"); |
| throw new HopsException(sb.printBlockErrorLocation() + "FunctionStatementBlock should not have hops"); |
| } |
| // step through stmt blocks in while stmt body |
| for (StatementBlock stmtBlock : body){ |
| constructLops(stmtBlock); |
| } |
| } |
| |
| // handle default case for regular StatementBlock |
| else { |
| |
| if (sb.get_hops() == null) |
| sb.set_hops(new ArrayList<Hop>()); |
| |
| ArrayList<Lop> lops = new ArrayList<Lop>(); |
| for (Hop hop : sb.get_hops()) { |
| lops.add(hop.constructLops()); |
| } |
| sb.setLops(lops); |
| sb.updateRecompilationFlag(); |
| } |
| |
| } // end method |
| |
| |
| public void printLops(DMLProgram dmlp) throws ParseException, LanguageException, HopsException, LopsException { |
| if (LOG.isDebugEnabled()){ |
| // for each namespace, handle function program blocks |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()){ |
| FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); |
| printLops(fsblock); |
| } |
| } |
| |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock current = dmlp.getStatementBlock(i); |
| printLops(current); |
| } |
| } |
| } |
| |
| public void printLops(StatementBlock current) throws ParseException, HopsException, LopsException { |
| if (LOG.isDebugEnabled()){ |
| ArrayList<Lop> lopsDAG = current.getLops(); |
| |
| LOG.debug("\n********************** LOPS DAG FOR BLOCK *******************"); |
| |
| if (current instanceof FunctionStatementBlock) { |
| if (current.getNumStatements() > 1) |
| LOG.debug("Function statement block has more than 1 stmt"); |
| FunctionStatement fstmt = (FunctionStatement)current.getStatement(0); |
| for (StatementBlock child : fstmt.getBody()){ |
| printLops(child); |
| } |
| } |
| |
| if (current instanceof WhileStatementBlock) { |
| |
| // print predicate lops |
| WhileStatementBlock wstb = (WhileStatementBlock) current; |
| Hop predicateHops = ((WhileStatementBlock) current).getPredicateHops(); |
| LOG.debug("\n********************** PREDICATE LOPS *******************"); |
| Lop predicateLops = predicateHops.getLops(); |
| if (predicateLops == null) |
| predicateLops = predicateHops.constructLops(); |
| predicateLops.printMe(); |
| |
| if (wstb.getNumStatements() > 1){ |
| LOG.error(wstb.printBlockErrorLocation() + "WhileStatementBlock has more than 1 statement"); |
| throw new HopsException(wstb.printBlockErrorLocation() + "WhileStatementBlock has more than 1 statement"); |
| } |
| WhileStatement ws = (WhileStatement)wstb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| printLops(sb); |
| } |
| } |
| |
| if (current instanceof IfStatementBlock) { |
| |
| // print predicate lops |
| IfStatementBlock istb = (IfStatementBlock) current; |
| Hop predicateHops = ((IfStatementBlock) current).getPredicateHops(); |
| LOG.debug("\n********************** PREDICATE LOPS *******************"); |
| Lop predicateLops = predicateHops.getLops(); |
| if (predicateLops == null) |
| predicateLops = predicateHops.constructLops(); |
| predicateLops.printMe(); |
| |
| if (istb.getNumStatements() > 1){ |
| LOG.error(istb.printBlockErrorLocation() + "IfStatmentBlock has more than 1 statement"); |
| throw new HopsException(istb.printBlockErrorLocation() + "IfStatmentBlock has more than 1 statement"); |
| } |
| IfStatement is = (IfStatement)istb.getStatement(0); |
| |
| LOG.debug("\n**** LOPS DAG FOR IF BODY ****"); |
| for (StatementBlock sb : is.getIfBody()){ |
| printLops(sb); |
| } |
| if ( !is.getElseBody().isEmpty() ){ |
| LOG.debug("\n**** LOPS DAG FOR IF BODY ****"); |
| for (StatementBlock sb : is.getElseBody()){ |
| printLops(sb); |
| } |
| } |
| } |
| |
| if (current instanceof ForStatementBlock) { |
| |
| // print predicate lops |
| ForStatementBlock fsb = (ForStatementBlock) current; |
| LOG.debug("\n********************** PREDICATE LOPS *******************"); |
| if( fsb.getFromHops() != null ){ |
| LOG.debug("FROM:"); |
| Lop llops = fsb.getFromLops(); |
| if( llops == null ) |
| llops = fsb.getFromHops().constructLops(); |
| llops.printMe(); |
| } |
| if( fsb.getToHops() != null ){ |
| LOG.debug("TO:"); |
| Lop llops = fsb.getToLops(); |
| if( llops == null ) |
| llops = fsb.getToHops().constructLops(); |
| llops.printMe(); |
| } |
| if( fsb.getIncrementHops() != null ){ |
| LOG.debug("INCREMENT:"); |
| Lop llops = fsb.getIncrementLops(); |
| if( llops == null ) |
| llops = fsb.getIncrementHops().constructLops(); |
| llops.printMe(); |
| } |
| |
| if (fsb.getNumStatements() > 1){ |
| LOG.error(fsb.printBlockErrorLocation() + "ForStatementBlock has more than 1 statement"); |
| throw new HopsException(fsb.printBlockErrorLocation() + "ForStatementBlock has more than 1 statement"); |
| } |
| ForStatement ws = (ForStatement)fsb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| printLops(sb); |
| } |
| } |
| |
| if (lopsDAG != null && !lopsDAG.isEmpty() ) { |
| Iterator<Lop> iter = lopsDAG.iterator(); |
| while (iter.hasNext()) { |
| LOG.debug("\n********************** OUTPUT LOPS *******************"); |
| iter.next().printMe(); |
| } |
| } |
| } |
| } |
| |
| |
| public void printHops(DMLProgram dmlp) throws ParseException, LanguageException, HopsException { |
| if (LOG.isDebugEnabled()) { |
| // for each namespace, handle function program blocks |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()){ |
| FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); |
| printHops(fsblock); |
| } |
| } |
| |
| // hand |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock current = dmlp.getStatementBlock(i); |
| printHops(current); |
| } |
| } |
| } |
| |
| public void printHops(StatementBlock current) throws ParseException, HopsException { |
| if (LOG.isDebugEnabled()) { |
| ArrayList<Hop> hopsDAG = current.get_hops(); |
| LOG.debug("\n********************** HOPS DAG FOR BLOCK *******************"); |
| |
| if (current instanceof FunctionStatementBlock) { |
| if (current.getNumStatements() > 1) |
| LOG.debug("Function statement block has more than 1 stmt"); |
| FunctionStatement fstmt = (FunctionStatement)current.getStatement(0); |
| for (StatementBlock child : fstmt.getBody()){ |
| printHops(child); |
| } |
| } |
| |
| if (current instanceof WhileStatementBlock) { |
| |
| // print predicate hops |
| WhileStatementBlock wstb = (WhileStatementBlock) current; |
| Hop predicateHops = wstb.getPredicateHops(); |
| LOG.debug("\n********************** PREDICATE HOPS *******************"); |
| predicateHops.printMe(); |
| |
| if (wstb.getNumStatements() > 1) |
| LOG.debug("While statement block has more than 1 stmt"); |
| WhileStatement ws = (WhileStatement)wstb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| printHops(sb); |
| } |
| } |
| |
| if (current instanceof IfStatementBlock) { |
| |
| // print predicate hops |
| IfStatementBlock istb = (IfStatementBlock) current; |
| Hop predicateHops = istb.getPredicateHops(); |
| LOG.debug("\n********************** PREDICATE HOPS *******************"); |
| predicateHops.printMe(); |
| |
| |
| if (istb.getNumStatements() > 1) |
| LOG.debug("If statement block has more than 1 stmt"); |
| IfStatement is = (IfStatement)istb.getStatement(0); |
| |
| for (StatementBlock sb : is.getIfBody()){ |
| printHops(sb); |
| } |
| |
| for (StatementBlock sb : is.getElseBody()){ |
| printHops(sb); |
| } |
| } |
| |
| |
| if (current instanceof ForStatementBlock) { |
| |
| // print predicate hops |
| ForStatementBlock fsb = (ForStatementBlock) current; |
| LOG.debug("\n********************** PREDICATE HOPS *******************"); |
| if (fsb.getFromHops() != null) fsb.getFromHops().printMe(); |
| if (fsb.getToHops() != null) fsb.getToHops().printMe(); |
| if (fsb.getIncrementHops() != null) fsb.getIncrementHops().printMe(); |
| |
| if (fsb.getNumStatements() > 1) |
| LOG.debug("For statement block has more than 1 stmt"); |
| ForStatement ws = (ForStatement)fsb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| printHops(sb); |
| } |
| } |
| |
| if (hopsDAG != null && !hopsDAG.isEmpty()) { |
| // hopsDAG.iterator().next().printMe(); |
| Iterator<Hop> iter = hopsDAG.iterator(); |
| while (iter.hasNext()) { |
| LOG.debug("\n********************** OUTPUT HOPS *******************"); |
| iter.next().printMe(); |
| } |
| } |
| } |
| } |
| |
| public void refreshMemEstimates(DMLProgram dmlp) throws ParseException, LanguageException, HopsException { |
| |
| // for each namespace, handle function program blocks -- forward direction |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()){ |
| FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname); |
| refreshMemEstimates(fsblock); |
| } |
| } |
| |
| // handle statement blocks in "main" method |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock current = dmlp.getStatementBlock(i); |
| refreshMemEstimates(current); |
| } |
| } |
| |
| public void refreshMemEstimates(StatementBlock current) throws ParseException, HopsException { |
| |
| MemoTable memo = new MemoTable(); |
| ArrayList<Hop> hopsDAG = current.get_hops(); |
| if (hopsDAG != null && !hopsDAG.isEmpty()) { |
| Iterator<Hop> iter = hopsDAG.iterator(); |
| while (iter.hasNext()) { |
| iter.next().refreshMemEstimates(memo); |
| } |
| } |
| |
| if (current instanceof FunctionStatementBlock) { |
| |
| FunctionStatement fstmt = (FunctionStatement)current.getStatement(0); |
| for (StatementBlock sb : fstmt.getBody()){ |
| refreshMemEstimates(sb); |
| } |
| } |
| |
| if (current instanceof WhileStatementBlock) { |
| // handle predicate |
| WhileStatementBlock wstb = (WhileStatementBlock) current; |
| wstb.getPredicateHops().refreshMemEstimates(new MemoTable()); |
| |
| if (wstb.getNumStatements() > 1) |
| LOG.debug("While statement block has more than 1 stmt"); |
| WhileStatement ws = (WhileStatement)wstb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| refreshMemEstimates(sb); |
| } |
| } |
| |
| if (current instanceof IfStatementBlock) { |
| // handle predicate |
| IfStatementBlock istb = (IfStatementBlock) current; |
| istb.getPredicateHops().refreshMemEstimates(new MemoTable()); |
| |
| if (istb.getNumStatements() > 1) |
| LOG.debug("If statement block has more than 1 stmt"); |
| IfStatement is = (IfStatement)istb.getStatement(0); |
| |
| for (StatementBlock sb : is.getIfBody()){ |
| refreshMemEstimates(sb); |
| } |
| for (StatementBlock sb : is.getElseBody()){ |
| refreshMemEstimates(sb); |
| } |
| } |
| |
| if (current instanceof ForStatementBlock) { |
| // handle predicate |
| ForStatementBlock fsb = (ForStatementBlock) current; |
| if (fsb.getFromHops() != null) |
| fsb.getFromHops().refreshMemEstimates(new MemoTable()); |
| if (fsb.getToHops() != null) |
| fsb.getToHops().refreshMemEstimates(new MemoTable()); |
| if (fsb.getIncrementHops() != null) |
| fsb.getIncrementHops().refreshMemEstimates(new MemoTable()); |
| |
| if (fsb.getNumStatements() > 1) |
| LOG.debug("For statement block has more than 1 stmt"); |
| ForStatement ws = (ForStatement)fsb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| refreshMemEstimates(sb); |
| } |
| } |
| } |
| |
| public static void resetHopsDAGVisitStatus(DMLProgram dmlp) throws ParseException, LanguageException, HopsException { |
| |
| // for each namespace, handle function program blocks -- forward direction |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()){ |
| FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname); |
| resetHopsDAGVisitStatus(fsblock); |
| } |
| } |
| |
| // handle statement blocks in "main" method |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock current = dmlp.getStatementBlock(i); |
| resetHopsDAGVisitStatus(current); |
| } |
| } |
| |
| public static void resetHopsDAGVisitStatus(StatementBlock current) throws ParseException, HopsException { |
| |
| ArrayList<Hop> hopsDAG = current.get_hops(); |
| if (hopsDAG != null && !hopsDAG.isEmpty() ) { |
| Hop.resetVisitStatus(hopsDAG); |
| } |
| |
| if (current instanceof FunctionStatementBlock) { |
| |
| FunctionStatement fstmt = (FunctionStatement)current.getStatement(0); |
| for (StatementBlock sb : fstmt.getBody()){ |
| resetHopsDAGVisitStatus(sb); |
| } |
| } |
| |
| if (current instanceof WhileStatementBlock) { |
| // handle predicate |
| WhileStatementBlock wstb = (WhileStatementBlock) current; |
| wstb.getPredicateHops().resetVisitStatus(); |
| |
| if (wstb.getNumStatements() > 1) |
| LOG.debug("While stmt block has more than 1 stmt"); |
| WhileStatement ws = (WhileStatement)wstb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| resetHopsDAGVisitStatus(sb); |
| } |
| } |
| |
| if (current instanceof IfStatementBlock) { |
| // handle predicate |
| IfStatementBlock istb = (IfStatementBlock) current; |
| istb.getPredicateHops().resetVisitStatus(); |
| |
| if (istb.getNumStatements() > 1) |
| LOG.debug("If statement block has more than 1 stmt"); |
| IfStatement is = (IfStatement)istb.getStatement(0); |
| |
| for (StatementBlock sb : is.getIfBody()){ |
| resetHopsDAGVisitStatus(sb); |
| } |
| for (StatementBlock sb : is.getElseBody()){ |
| resetHopsDAGVisitStatus(sb); |
| } |
| } |
| |
| if (current instanceof ForStatementBlock) { |
| // handle predicate |
| ForStatementBlock fsb = (ForStatementBlock) current; |
| if (fsb.getFromHops() != null) |
| fsb.getFromHops().resetVisitStatus(); |
| if (fsb.getToHops() != null) |
| fsb.getToHops().resetVisitStatus(); |
| if (fsb.getIncrementHops() != null) |
| fsb.getIncrementHops().resetVisitStatus(); |
| |
| if (fsb.getNumStatements() > 1) |
| LOG.debug("For statment block has more than 1 stmt"); |
| ForStatement ws = (ForStatement)fsb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| resetHopsDAGVisitStatus(sb); |
| } |
| } |
| } |
| |
| public void resetLopsDAGVisitStatus(DMLProgram dmlp) throws HopsException, LanguageException { |
| |
| // for each namespace, handle function program blocks |
| for (String namespaceKey : dmlp.getNamespaces().keySet()){ |
| for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()){ |
| FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname); |
| resetLopsDAGVisitStatus(fsblock); |
| } |
| } |
| |
| for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { |
| StatementBlock current = dmlp.getStatementBlock(i); |
| resetLopsDAGVisitStatus(current); |
| } |
| } |
| |
| public void resetLopsDAGVisitStatus(StatementBlock current) throws HopsException { |
| |
| ArrayList<Hop> hopsDAG = current.get_hops(); |
| |
| if (hopsDAG != null && !hopsDAG.isEmpty() ) { |
| Iterator<Hop> iter = hopsDAG.iterator(); |
| while (iter.hasNext()){ |
| Hop currentHop = iter.next(); |
| currentHop.getLops().resetVisitStatus(); |
| } |
| } |
| |
| if (current instanceof FunctionStatementBlock) { |
| FunctionStatementBlock fsb = (FunctionStatementBlock) current; |
| FunctionStatement fs = (FunctionStatement)fsb.getStatement(0); |
| |
| for (StatementBlock sb : fs.getBody()){ |
| resetLopsDAGVisitStatus(sb); |
| } |
| } |
| |
| |
| if (current instanceof WhileStatementBlock) { |
| WhileStatementBlock wstb = (WhileStatementBlock) current; |
| wstb.get_predicateLops().resetVisitStatus(); |
| if (wstb.getNumStatements() > 1) |
| LOG.debug("While statement block has more than 1 stmt"); |
| WhileStatement ws = (WhileStatement)wstb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| resetLopsDAGVisitStatus(sb); |
| } |
| } |
| |
| if (current instanceof IfStatementBlock) { |
| IfStatementBlock istb = (IfStatementBlock) current; |
| istb.get_predicateLops().resetVisitStatus(); |
| if (istb.getNumStatements() > 1) |
| LOG.debug("If statement block has more than 1 stmt"); |
| IfStatement is = (IfStatement)istb.getStatement(0); |
| |
| for (StatementBlock sb : is.getIfBody()){ |
| resetLopsDAGVisitStatus(sb); |
| } |
| |
| for (StatementBlock sb : is.getElseBody()){ |
| resetLopsDAGVisitStatus(sb); |
| } |
| } |
| |
| if (current instanceof ForStatementBlock) { |
| ForStatementBlock fsb = (ForStatementBlock) current; |
| |
| if (fsb.getFromLops() != null) |
| fsb.getFromLops().resetVisitStatus(); |
| if (fsb.getToLops() != null) |
| fsb.getToLops().resetVisitStatus(); |
| if (fsb.getIncrementLops() != null) |
| fsb.getIncrementLops().resetVisitStatus(); |
| |
| if (fsb.getNumStatements() > 1) |
| LOG.debug("For statement block has more than 1 stmt"); |
| ForStatement ws = (ForStatement)fsb.getStatement(0); |
| |
| for (StatementBlock sb : ws.getBody()){ |
| resetLopsDAGVisitStatus(sb); |
| } |
| } |
| } |
| |
| |
| public void constructHops(StatementBlock sb) |
| throws ParseException, LanguageException { |
| |
| if (sb instanceof WhileStatementBlock) { |
| constructHopsForWhileControlBlock((WhileStatementBlock) sb); |
| return; |
| } |
| |
| if (sb instanceof IfStatementBlock) { |
| constructHopsForIfControlBlock((IfStatementBlock) sb); |
| return; |
| } |
| |
| if (sb instanceof ForStatementBlock) { //NOTE: applies to ForStatementBlock and ParForStatementBlock |
| constructHopsForForControlBlock((ForStatementBlock) sb); |
| return; |
| } |
| |
| if (sb instanceof FunctionStatementBlock) { |
| constructHopsForFunctionControlBlock((FunctionStatementBlock) sb); |
| return; |
| } |
| |
| |
| HashMap<String, Hop> ids = new HashMap<String, Hop>(); |
| ArrayList<Hop> output = new ArrayList<Hop>(); |
| |
| VariableSet liveIn = sb.liveIn(); |
| VariableSet liveOut = sb.liveOut(); |
| VariableSet updated = sb._updated; |
| VariableSet gen = sb._gen; |
| VariableSet updatedLiveOut = new VariableSet(); |
| |
| // handle liveout variables that are updated --> target identifiers for Assignment |
| HashMap<String, Integer> liveOutToTemp = new HashMap<String, Integer>(); |
| for (int i = 0; i < sb.getNumStatements(); i++) { |
| Statement current = sb.getStatement(i); |
| |
| if (current instanceof AssignmentStatement) { |
| AssignmentStatement as = (AssignmentStatement) current; |
| DataIdentifier target = as.getTarget(); |
| if (liveOut.containsVariable(target.getName())) { |
| liveOutToTemp.put(target.getName(), Integer.valueOf(i)); |
| } |
| } |
| if (current instanceof MultiAssignmentStatement) { |
| MultiAssignmentStatement mas = (MultiAssignmentStatement) current; |
| |
| for (DataIdentifier target : mas.getTargetList()){ |
| if (liveOut.containsVariable(target.getName())) { |
| liveOutToTemp.put(target.getName(), Integer.valueOf(i)); |
| } |
| } |
| } |
| } |
| |
| // only create transient read operations for variables either updated or read-before-update |
| // (i.e., from LV analysis, updated and gen sets) |
| if ( !liveIn.getVariables().values().isEmpty() ) { |
| |
| for (String varName : liveIn.getVariables().keySet()) { |
| |
| if (updated.containsVariable(varName) || gen.containsVariable(varName)){ |
| |
| DataIdentifier var = liveIn.getVariables().get(varName); |
| long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1(); |
| long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2(); |
| DataOp read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD, null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock()); |
| read.setAllPositions(var.getBeginLine(), var.getBeginColumn(), var.getEndLine(), var.getEndColumn()); |
| ids.put(varName, read); |
| } |
| } |
| } |
| |
| |
| for( int i = 0; i < sb.getNumStatements(); i++ ) { |
| Statement current = sb.getStatement(i); |
| |
| if (current instanceof OutputStatement) { |
| OutputStatement os = (OutputStatement) current; |
| |
| DataExpression source = os.getSource(); |
| DataIdentifier target = os.getIdentifier(); |
| |
| //error handling unsupported indexing expression in write statement |
| if( target instanceof IndexedIdentifier ) { |
| throw new LanguageException(source.printErrorLocation()+": Unsupported indexing expression in write statement. " + |
| "Please, assign the right indexing result to a variable and write this variable."); |
| } |
| |
| DataOp ae = (DataOp)processExpression(source, target, ids); |
| String formatName = os.getExprParam(DataExpression.FORMAT_TYPE).toString(); |
| ae.setInputFormatType(Expression.convertFormatType(formatName)); |
| |
| if (ae.getDataType() == DataType.SCALAR ) { |
| ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1); |
| } |
| else { |
| switch(ae.getInputFormatType()) { |
| case TEXT: |
| case MM: |
| case CSV: |
| // write output in textcell format |
| ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1); |
| break; |
| |
| case BINARY: |
| // write output in binary block format |
| ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()); |
| break; |
| |
| default: |
| throw new LanguageException("Unrecognized file format: " + ae.getInputFormatType()); |
| } |
| } |
| |
| output.add(ae); |
| |
| } |
| |
| if (current instanceof PrintStatement) { |
| PrintStatement ps = (PrintStatement) current; |
| Expression source = ps.getExpression(); |
| PRINTTYPE ptype = ps.getType(); |
| |
| DataIdentifier target = createTarget(); |
| target.setDataType(DataType.SCALAR); |
| target.setValueType(ValueType.STRING); |
| target.setAllPositions(current.getFilename(), current.getBeginLine(), target.getBeginColumn(), current.getEndLine(), current.getEndColumn()); |
| |
| Hop ae = processExpression(source, target, ids); |
| |
| try { |
| Hop.OpOp1 op = (ptype == PRINTTYPE.PRINT ? Hop.OpOp1.PRINT : Hop.OpOp1.STOP); |
| Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae); |
| printHop.setAllPositions(current.getBeginLine(), current.getBeginColumn(), current.getEndLine(), current.getEndColumn()); |
| output.add(printHop); |
| } catch ( HopsException e ) { |
| throw new LanguageException(e); |
| } |
| } |
| |
| if (current instanceof AssignmentStatement) { |
| |
| AssignmentStatement as = (AssignmentStatement) current; |
| DataIdentifier target = as.getTarget(); |
| Expression source = as.getSource(); |
| |
| |
| // CASE: regular assignment statement -- source is DML expression that is NOT user-defined or external function |
| if (!(source instanceof FunctionCallIdentifier)){ |
| |
| // CASE: target is regular data identifier |
| if (!(target instanceof IndexedIdentifier)) { |
| |
| Hop ae = processExpression(source, target, ids); |
| ids.put(target.getName(), ae); |
| target.setProperties(source.getOutput()); |
| Integer statementId = liveOutToTemp.get(target.getName()); |
| if ((statementId != null) && (statementId.intValue() == i)) { |
| DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null); |
| transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock()); |
| transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndLine()); |
| updatedLiveOut.addVariable(target.getName(), target); |
| output.add(transientwrite); |
| } |
| } // end if (!(target instanceof IndexedIdentifier)) { |
| |
| // CASE: target is indexed identifier (left-hand side indexed expression) |
| else { |
| Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier)target, ids); |
| |
| ids.put(target.getName(), ae); |
| |
| // obtain origDim values BEFORE they are potentially updated during setProperties call |
| // (this is incorrect for LHS Indexing) |
| long origDim1 = ((IndexedIdentifier)target).getOrigDim1(); |
| long origDim2 = ((IndexedIdentifier)target).getOrigDim2(); |
| target.setProperties(source.getOutput()); |
| ((IndexedIdentifier)target).setOriginalDimensions(origDim1, origDim2); |
| |
| // preserve data type matrix of any index identifier |
| // (required for scalar input to left indexing) |
| if( target.getDataType() != DataType.MATRIX ) { |
| target.setDataType(DataType.MATRIX); |
| target.setValueType(ValueType.DOUBLE); |
| target.setBlockDimensions(ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()); |
| } |
| |
| Integer statementId = liveOutToTemp.get(target.getName()); |
| if ((statementId != null) && (statementId.intValue() == i)) { |
| DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null); |
| transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock()); |
| transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn()); |
| updatedLiveOut.addVariable(target.getName(), target); |
| output.add(transientwrite); |
| } |
| } |
| |
| |
| } |
| else |
| { |
| //assignment, function call |
| FunctionCallIdentifier fci = (FunctionCallIdentifier) source; |
| FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName()); |
| |
| //error handling missing function |
| if (fsb == null){ |
| String error = source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace(); |
| LOG.error(error); |
| throw new LanguageException(error); |
| } |
| |
| //error handling unsupported function call in indexing expression |
| if( target instanceof IndexedIdentifier ) { |
| String fkey = DMLProgram.constructFunctionKey(fci.getNamespace(),fci.getName()); |
| throw new LanguageException("Unsupported function call to '"+fkey+"' in left indexing expression. " + |
| "Please, assign the function output to a variable."); |
| } |
| |
| ArrayList<Hop> finputs = new ArrayList<Hop>(); |
| for (ParameterExpression paramName : fci.getParamExprs()){ |
| Hop in = processExpression(paramName.getExpr(), null, ids); |
| finputs.add(in); |
| } |
| |
| //create function op |
| FunctionType ftype = fsb.getFunctionOpType(); |
| FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, new String[]{target.getName()}); |
| output.add(fcall); |
| |
| //TODO function output dataops (phase 3) |
| //DataOp trFoutput = new DataOp(target.getName(), target.getDataType(), target.getValueType(), fcall, DataOpTypes.FUNCTIONOUTPUT, null); |
| //DataOp twFoutput = new DataOp(target.getName(), target.getDataType(), target.getValueType(), trFoutput, DataOpTypes.TRANSIENTWRITE, null); |
| } |
| } |
| |
| else if (current instanceof MultiAssignmentStatement) { |
| //multi-assignment, by definition a function call |
| MultiAssignmentStatement mas = (MultiAssignmentStatement) current; |
| Expression source = mas.getSource(); |
| |
| if ( source instanceof FunctionCallIdentifier ) { |
| FunctionCallIdentifier fci = (FunctionCallIdentifier) source; |
| FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName()); |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| if (fstmt == null){ |
| LOG.error(source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace()); |
| throw new LanguageException(source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace()); |
| } |
| |
| ArrayList<Hop> finputs = new ArrayList<Hop>(); |
| for (ParameterExpression paramName : fci.getParamExprs()){ |
| Hop in = processExpression(paramName.getExpr(), null, ids); |
| finputs.add(in); |
| } |
| |
| //create function op |
| String[] foutputs = new String[mas.getTargetList().size()]; |
| int count = 0; |
| for ( DataIdentifier paramName : mas.getTargetList() ){ |
| foutputs[count++]=paramName.getName(); |
| } |
| |
| FunctionType ftype = fsb.getFunctionOpType(); |
| FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, foutputs); |
| output.add(fcall); |
| |
| //TODO function output dataops (phase 3) |
| /*for ( DataIdentifier paramName : mas.getTargetList() ){ |
| DataOp twFoutput = new DataOp(paramName.getName(), paramName.getDataType(), paramName.getValueType(), fcall, DataOpTypes.TRANSIENTWRITE, null); |
| output.add(twFoutput); |
| }*/ |
| } |
| else if ( source instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression)source).multipleReturns() ) { |
| // construct input hops |
| Hop fcall = processMultipleReturnBuiltinFunctionExpression((BuiltinFunctionExpression)source, mas.getTargetList(), ids); |
| output.add(fcall); |
| } |
| else if ( source instanceof ParameterizedBuiltinFunctionExpression && ((ParameterizedBuiltinFunctionExpression)source).multipleReturns() ) { |
| // construct input hops |
| Hop fcall = processMultipleReturnParameterizedBuiltinFunctionExpression((ParameterizedBuiltinFunctionExpression)source, mas.getTargetList(), ids); |
| output.add(fcall); |
| } |
| else |
| throw new LanguageException("Class \"" + source.getClass() + "\" is not supported in Multiple Assignment statements"); |
| } |
| |
| } |
| sb.updateLiveVariablesOut(updatedLiveOut); |
| sb.set_hops(output); |
| |
| } |
| |
| public void constructHopsForIfControlBlock(IfStatementBlock sb) throws ParseException, LanguageException { |
| |
| IfStatement ifsb = (IfStatement) sb.getStatement(0); |
| ArrayList<StatementBlock> ifBody = ifsb.getIfBody(); |
| ArrayList<StatementBlock> elseBody = ifsb.getElseBody(); |
| |
| // construct hops for predicate in if statement |
| constructHopsForConditionalPredicate(sb); |
| |
| // handle if statement body |
| for( StatementBlock current : ifBody ) { |
| constructHops(current); |
| } |
| |
| // handle else stmt body |
| for( StatementBlock current : elseBody ) { |
| constructHops(current); |
| } |
| } |
| |
| /** |
| * Constructs Hops for a given ForStatementBlock or ParForStatementBlock, respectively. |
| * |
| * @param sb |
| * @throws ParseException |
| * @throws LanguageException |
| */ |
| public void constructHopsForForControlBlock(ForStatementBlock sb) |
| throws ParseException, LanguageException |
| { |
| |
| ForStatement fs = (ForStatement) sb.getStatement(0); |
| ArrayList<StatementBlock> body = fs.getBody(); |
| |
| // construct hops for iterable predicate |
| constructHopsForIterablePredicate(sb); |
| |
| for( StatementBlock current : body ) { |
| constructHops(current); |
| } |
| } |
| |
| public void constructHopsForFunctionControlBlock(FunctionStatementBlock fsb) throws ParseException, LanguageException { |
| |
| ArrayList<StatementBlock> body = ((FunctionStatement)fsb.getStatement(0)).getBody(); |
| |
| for( StatementBlock current : body ) { |
| constructHops(current); |
| } |
| } |
| |
| public void constructHopsForWhileControlBlock(WhileStatementBlock sb) |
| throws ParseException, LanguageException { |
| |
| ArrayList<StatementBlock> body = ((WhileStatement)sb.getStatement(0)).getBody(); |
| |
| // construct hops for while predicate |
| constructHopsForConditionalPredicate(sb); |
| |
| for( StatementBlock current : body ) { |
| constructHops(current); |
| } |
| } |
| |
| |
| public void constructHopsForConditionalPredicate(StatementBlock passedSB) throws ParseException { |
| |
| HashMap<String, Hop> _ids = new HashMap<String, Hop>(); |
| |
| // set conditional predicate |
| ConditionalPredicate cp = null; |
| |
| if (passedSB instanceof WhileStatementBlock){ |
| WhileStatement ws = (WhileStatement) ((WhileStatementBlock)passedSB).getStatement(0); |
| cp = ws.getConditionalPredicate(); |
| } |
| else if (passedSB instanceof IfStatementBlock) { |
| IfStatement ws = (IfStatement) ((IfStatementBlock)passedSB).getStatement(0); |
| cp = ws.getConditionalPredicate(); |
| } |
| else { |
| throw new ParseException("ConditionalPredicate expected only for while or if statements."); |
| } |
| |
| VariableSet varsRead = cp.variablesRead(); |
| |
| for (String varName : varsRead.getVariables().keySet()) { |
| |
| // creating transient read for live in variables |
| DataIdentifier var = passedSB.liveIn().getVariables().get(varName); |
| |
| DataOp read = null; |
| |
| if (var == null) { |
| LOG.error("variable " + varName + " not live variable for conditional predicate"); |
| throw new ParseException("variable " + varName + " not live variable for conditional predicate"); |
| } else { |
| long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1(); |
| long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2(); |
| |
| read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD, |
| null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock()); |
| read.setAllPositions(var.getBeginLine(), var.getBeginColumn(), var.getEndLine(), var.getEndColumn()); |
| } |
| _ids.put(varName, read); |
| } |
| |
| DataIdentifier target = new DataIdentifier(Expression.getTempName()); |
| target.setDataType(DataType.SCALAR); |
| target.setValueType(ValueType.BOOLEAN); |
| target.setAllPositions(passedSB.getFilename(), passedSB.getBeginLine(), passedSB.getBeginColumn(), passedSB.getEndLine(), passedSB.getEndColumn()); |
| Hop predicateHops = null; |
| Expression predicate = cp.getPredicate(); |
| |
| if (predicate instanceof RelationalExpression) { |
| predicateHops = processRelationalExpression((RelationalExpression) cp.getPredicate(), target, _ids); |
| } else if (predicate instanceof BooleanExpression) { |
| predicateHops = processBooleanExpression((BooleanExpression) cp.getPredicate(), target, _ids); |
| } else if (predicate instanceof DataIdentifier) { |
| // handle data identifier predicate |
| predicateHops = processExpression(cp.getPredicate(), null, _ids); |
| } else if (predicate instanceof ConstIdentifier) { |
| // handle constant identifier |
| // a) translate 0 --> FALSE; translate 1 --> TRUE |
| // b) disallow string values |
| if ( (predicate instanceof IntIdentifier && ((IntIdentifier)predicate).getValue() == 0) || (predicate instanceof DoubleIdentifier && ((DoubleIdentifier)predicate).getValue() == 0.0)) { |
| cp.setPredicate(new BooleanIdentifier(false, |
| predicate.getFilename(), |
| predicate.getBeginLine(), predicate.getBeginColumn(), |
| predicate.getEndLine(), predicate.getEndColumn())); |
| |
| } |
| else if ( (predicate instanceof IntIdentifier && ((IntIdentifier)predicate).getValue() == 1) || (predicate instanceof DoubleIdentifier && ((DoubleIdentifier)predicate).getValue() == 1.0)) { |
| cp.setPredicate(new BooleanIdentifier(true, |
| predicate.getFilename(), |
| predicate.getBeginLine(), predicate.getBeginColumn(), |
| predicate.getEndLine(), predicate.getEndColumn())); |
| } |
| else if (predicate instanceof IntIdentifier || predicate instanceof DoubleIdentifier){ |
| cp.setPredicate(new BooleanIdentifier(true, |
| predicate.getFilename(), |
| predicate.getBeginLine(), predicate.getBeginColumn(), |
| predicate.getEndLine(), predicate.getEndColumn())); |
| LOG.warn(predicate.printWarningLocation() + "Numerical value '" + predicate.toString() + "' (!= 0/1) is converted to boolean TRUE by DML"); |
| } |
| else if (predicate instanceof StringIdentifier) { |
| LOG.error(predicate.printErrorLocation() + "String value '" + predicate.toString() + "' is not allowed for iterable predicate"); |
| throw new ParseException(predicate.printErrorLocation() + "String value '" + predicate.toString() + "' is not allowed for iterable predicate"); |
| |
| } |
| predicateHops = processExpression(cp.getPredicate(), null, _ids); |
| } |
| if (passedSB instanceof WhileStatementBlock) |
| ((WhileStatementBlock)passedSB).setPredicateHops(predicateHops); |
| else if (passedSB instanceof IfStatementBlock) |
| ((IfStatementBlock)passedSB).setPredicateHops(predicateHops); |
| } |
| |
| |
| /** |
| * Constructs all predicate Hops (for FROM, TO, INCREMENT) of an iterable predicate |
| * and assigns these Hops to the passed statement block. |
| * |
| * Method used for both ForStatementBlock and ParForStatementBlock. |
| * |
| * @param passedSB |
| * @throws ParseException |
| */ |
| public void constructHopsForIterablePredicate(ForStatementBlock fsb) |
| throws ParseException |
| { |
| HashMap<String, Hop> _ids = new HashMap<String, Hop>(); |
| |
| // set iterable predicate |
| ForStatement fs = (ForStatement) fsb.getStatement(0); |
| IterablePredicate ip = fs.getIterablePredicate(); |
| |
| for(int i=0; i < 3; i++) { |
| VariableSet varsRead = null; |
| if (i==0) |
| varsRead = ip.getFromExpr().variablesRead(); |
| else if (i==1) |
| varsRead = ip.getToExpr().variablesRead(); |
| else if( ip.getIncrementExpr() != null ) |
| varsRead = ip.getIncrementExpr().variablesRead(); |
| |
| if(varsRead != null) { |
| for (String varName : varsRead.getVariables().keySet()) { |
| |
| DataIdentifier var = fsb.liveIn().getVariable(varName); |
| DataOp read = null; |
| if (var == null) { |
| LOG.error("variable '" + varName + "' is not available for iterable predicate"); |
| throw new ParseException("variable '" + varName + "' is not available for iterable predicate"); |
| } |
| else { |
| long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1(); |
| long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2(); |
| read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD, |
| null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock()); |
| read.setAllPositions(var.getBeginLine(), var.getBeginColumn(), var.getEndLine(), var.getEndColumn()); |
| } |
| _ids.put(varName, read); |
| } |
| } |
| |
| //construct hops for from, to, and increment expressions |
| if(i==0) |
| fsb.setFromHops( processTempIntExpression( ip.getFromExpr(), _ids )); |
| else if(i==1) |
| fsb.setToHops( processTempIntExpression( ip.getToExpr(), _ids )); |
| else if( ip.getIncrementExpr() != null ) |
| fsb.setIncrementHops( processTempIntExpression( ip.getIncrementExpr(), _ids )); |
| |
| } |
| |
| /*VariableSet varsRead = ip.variablesRead(); |
| |
| for (String varName : varsRead.getVariables().keySet()) { |
| |
| DataIdentifier var = passedSB.liveIn().getVariable(varName); |
| DataOp read = null; |
| if (var == null) { |
| LOG.error(var.printErrorLocation() + "variable '" + varName + "' is not available for iterable predicate"); |
| throw new ParseException(var.printErrorLocation() + "variable '" + varName + "' is not available for iterable predicate"); |
| } |
| else { |
| long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1(); |
| long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2(); |
| read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD, |
| null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock()); |
| read.setAllPositions(var.getBeginLine(), var.getBeginColumn(), var.getEndLine(), var.getEndColumn()); |
| } |
| _ids.put(varName, read); |
| } |
| |
| //construct hops for from, to, and increment expressions |
| fsb.setFromHops( processTempIntExpression( ip.getFromExpr(), _ids )); |
| fsb.setToHops( processTempIntExpression( ip.getToExpr(), _ids )); |
| fsb.setIncrementHops( processTempIntExpression( ip.getIncrementExpr(), _ids ));*/ |
| } |
| |
| |
| /** |
| * Construct Hops from parse tree : Process Expression in an assignment |
| * statement |
| * |
| * @throws ParseException |
| */ |
| private Hop processExpression(Expression source, DataIdentifier target, HashMap<String, Hop> hops) throws ParseException { |
| if (source.getKind() == Expression.Kind.BinaryOp) { |
| return processBinaryExpression((BinaryExpression) source, target, hops); |
| } else if (source.getKind() == Expression.Kind.RelationalOp) { |
| return processRelationalExpression((RelationalExpression) source, target, hops); |
| } else if (source.getKind() == Expression.Kind.BooleanOp) { |
| return processBooleanExpression((BooleanExpression) source, target, hops); |
| } else if (source.getKind() == Expression.Kind.Data) { |
| if (source instanceof IndexedIdentifier){ |
| IndexedIdentifier sourceIndexed = (IndexedIdentifier) source; |
| return processIndexingExpression(sourceIndexed,target,hops); |
| } else if (source instanceof IntIdentifier) { |
| IntIdentifier sourceInt = (IntIdentifier) source; |
| LiteralOp litop = new LiteralOp(sourceInt.getValue()); |
| litop.setAllPositions(sourceInt.getBeginLine(), sourceInt.getBeginColumn(), sourceInt.getEndLine(), sourceInt.getEndColumn()); |
| setIdentifierParams(litop, sourceInt); |
| return litop; |
| } else if (source instanceof DoubleIdentifier) { |
| DoubleIdentifier sourceDouble = (DoubleIdentifier) source; |
| LiteralOp litop = new LiteralOp(sourceDouble.getValue()); |
| litop.setAllPositions(sourceDouble.getBeginLine(), sourceDouble.getBeginColumn(), sourceDouble.getEndLine(), sourceDouble.getEndColumn()); |
| setIdentifierParams(litop, sourceDouble); |
| return litop; |
| } else if (source instanceof DataIdentifier) { |
| DataIdentifier sourceId = (DataIdentifier) source; |
| return hops.get(sourceId.getName()); |
| } else if (source instanceof BooleanIdentifier) { |
| BooleanIdentifier sourceBoolean = (BooleanIdentifier) source; |
| LiteralOp litop = new LiteralOp(sourceBoolean.getValue()); |
| litop.setAllPositions(sourceBoolean.getBeginLine(), sourceBoolean.getBeginColumn(), sourceBoolean.getEndLine(), sourceBoolean.getEndColumn()); |
| setIdentifierParams(litop, sourceBoolean); |
| return litop; |
| } else if (source instanceof StringIdentifier) { |
| StringIdentifier sourceString = (StringIdentifier) source; |
| LiteralOp litop = new LiteralOp(sourceString.getValue()); |
| litop.setAllPositions(sourceString.getBeginLine(), sourceString.getBeginColumn(), sourceString.getEndLine(), sourceString.getEndColumn()); |
| setIdentifierParams(litop, sourceString); |
| return litop; |
| } |
| } else if (source.getKind() == Expression.Kind.BuiltinFunctionOp) { |
| try { |
| return processBuiltinFunctionExpression((BuiltinFunctionExpression) source, target, hops); |
| } catch (HopsException e) { |
| throw new ParseException(e.getMessage()); |
| } |
| } else if (source.getKind() == Expression.Kind.ParameterizedBuiltinFunctionOp ) { |
| try { |
| return processParameterizedBuiltinFunctionExpression((ParameterizedBuiltinFunctionExpression)source, target, hops); |
| } catch ( HopsException e ) { |
| throw new ParseException(e.getMessage()); |
| } |
| } else if (source.getKind() == Expression.Kind.DataOp ) { |
| try { |
| Hop ae = (Hop)processDataExpression((DataExpression)source, target, hops); |
| |
| if (ae instanceof DataOp){ |
| String formatName = ((DataExpression)source).getVarParam(DataExpression.FORMAT_TYPE).toString(); |
| ((DataOp)ae).setInputFormatType(Expression.convertFormatType(formatName)); |
| } |
| //hops.put(target.getName(), ae); |
| return ae; |
| } catch ( Exception e ) { |
| throw new ParseException(e.getMessage()); |
| } |
| } |
| return null; |
| } // end method processExpression |
| |
| private DataIdentifier createTarget(Expression source) { |
| Identifier id = source.getOutput(); |
| if (id instanceof DataIdentifier && !(id instanceof DataExpression)) |
| return (DataIdentifier) id; |
| DataIdentifier target = new DataIdentifier(Expression.getTempName()); |
| target.setProperties(id); |
| return target; |
| } |
| |
| private DataIdentifier createTarget() { |
| DataIdentifier target = new DataIdentifier(Expression.getTempName()); |
| return target; |
| } |
| |
| |
| /** |
| * Constructs the Hops for arbitrary expressions that eventually evaluate to an INT scalar. |
| * |
| * @param source |
| * @param hops |
| * @return |
| * @throws ParseException |
| */ |
| private Hop processTempIntExpression( Expression source, HashMap<String, Hop> hops ) |
| throws ParseException |
| { |
| DataIdentifier tmpOut = createTarget(); |
| tmpOut.setDataType(DataType.SCALAR); |
| tmpOut.setValueType(ValueType.INT); |
| source.setOutput(tmpOut); |
| |
| return processExpression(source, tmpOut, hops ); |
| } |
| |
| private Hop processLeftIndexedExpression(Expression source, IndexedIdentifier target, HashMap<String, Hop> hops) |
| throws ParseException { |
| |
| // process target indexed expressions |
| Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = null, colUpperHops = null; |
| |
| if (target.getRowLowerBound() != null) |
| rowLowerHops = processExpression(target.getRowLowerBound(),null,hops); |
| else |
| rowLowerHops = new LiteralOp(1); |
| |
| if (target.getRowUpperBound() != null) |
| rowUpperHops = processExpression(target.getRowUpperBound(),null,hops); |
| else |
| { |
| if ( target.getDim1() != -1 ) |
| rowUpperHops = new LiteralOp(target.getOrigDim1()); |
| else |
| { |
| try { |
| //currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hops.OpOp1.NROW, expr); |
| rowUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(target.getName())); |
| rowUpperHops.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn()); |
| } catch (HopsException e) { |
| LOG.error(target.printErrorLocation() + "error processing row upper index for indexed expression " + target.toString()); |
| throw new RuntimeException(target.printErrorLocation() + "error processing row upper index for indexed expression " + target.toString()); |
| } |
| } |
| } |
| if (target.getColLowerBound() != null) |
| colLowerHops = processExpression(target.getColLowerBound(),null,hops); |
| else |
| colLowerHops = new LiteralOp(1); |
| |
| if (target.getColUpperBound() != null) |
| colUpperHops = processExpression(target.getColUpperBound(),null,hops); |
| else |
| { |
| if ( target.getDim2() != -1 ) |
| colUpperHops = new LiteralOp(target.getOrigDim2()); |
| else |
| { |
| try { |
| colUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(target.getName())); |
| } catch (HopsException e) { |
| LOG.error(target.printErrorLocation() + " error processing column upper index for indexed expression " + target.toString()); |
| throw new RuntimeException(target.printErrorLocation() + " error processing column upper index for indexed expression " + target.toString(), e); |
| } |
| } |
| } |
| |
| //if (target == null) { |
| // target = createTarget(source); |
| //} |
| |
| // process the source expression to get source Hops |
| Hop sourceOp = processExpression(source, target, hops); |
| |
| // process the target to get targetHops |
| Hop targetOp = hops.get(target.getName()); |
| if (targetOp == null){ |
| LOG.error(target.printErrorLocation() + " must define matrix " + target.getName() + " before indexing operations are allowed "); |
| throw new ParseException(target.printErrorLocation() + " must define matrix " + target.getName() + " before indexing operations are allowed "); |
| } |
| |
| //TODO Doug, please verify this (we need probably a cleaner way than this postprocessing) |
| if( sourceOp.getDataType() == DataType.MATRIX && source.getOutput().getDataType() == DataType.SCALAR ) |
| sourceOp.setDataType(DataType.SCALAR); |
| |
| Hop leftIndexOp = new LeftIndexingOp(target.getName(), target.getDataType(), target.getValueType(), |
| targetOp, sourceOp, rowLowerHops, rowUpperHops, colLowerHops, colUpperHops, |
| target.getRowLowerEqualsUpper(), target.getColLowerEqualsUpper()); |
| |
| setIdentifierParams(leftIndexOp, target); |
| |
| leftIndexOp.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn()); |
| leftIndexOp.setDim1(target.getOrigDim1()); |
| leftIndexOp.setDim2(target.getOrigDim2()); |
| |
| return leftIndexOp; |
| } |
| |
| |
| private Hop processIndexingExpression(IndexedIdentifier source, DataIdentifier target, HashMap<String, Hop> hops) |
| throws ParseException { |
| |
| // process Hops for indexes (for source) |
| Hop rowLowerHops = null, rowUpperHops = null, colLowerHops = null, colUpperHops = null; |
| |
| if (source.getRowLowerBound() != null) |
| rowLowerHops = processExpression(source.getRowLowerBound(),null,hops); |
| else |
| rowLowerHops = new LiteralOp(1); |
| |
| if (source.getRowUpperBound() != null) |
| rowUpperHops = processExpression(source.getRowUpperBound(),null,hops); |
| else |
| { |
| if ( source.getOrigDim1() != -1 ) |
| rowUpperHops = new LiteralOp(source.getOrigDim1()); |
| else |
| { |
| try { |
| //currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hops.OpOp1.NROW, expr); |
| rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(source.getName())); |
| rowUpperHops.setAllPositions(source.getBeginLine(),source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| } catch (HopsException e) { |
| LOG.error(source.printErrorLocation() + "error processing row upper index for indexed identifier " + source.toString()); |
| throw new RuntimeException(source.printErrorLocation() + "error processing row upper index for indexed identifier " + source.toString() + e); |
| } |
| } |
| } |
| if (source.getColLowerBound() != null) |
| colLowerHops = processExpression(source.getColLowerBound(),null,hops); |
| else |
| colLowerHops = new LiteralOp(1); |
| |
| if (source.getColUpperBound() != null) |
| colUpperHops = processExpression(source.getColUpperBound(),null,hops); |
| else |
| { |
| if ( source.getOrigDim2() != -1 ) |
| colUpperHops = new LiteralOp(source.getOrigDim2()); |
| else |
| { |
| try { |
| colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(source.getName())); |
| } catch (HopsException e) { |
| LOG.error(source.printErrorLocation() + "error processing column upper index for indexed indentifier " + source.toString(), e); |
| throw new RuntimeException(source.printErrorLocation() + "error processing column upper index for indexed indentifier " + source.toString(), e); |
| } |
| } |
| } |
| |
| if (target == null) { |
| target = createTarget(source); |
| } |
| //unknown nnz after range indexing (applies to indexing op but also |
| //data dependent operations) |
| target.setNnz(-1); |
| |
| Hop indexOp = new IndexingOp(target.getName(), target.getDataType(), target.getValueType(), |
| hops.get(source.getName()), rowLowerHops, rowUpperHops, colLowerHops, colUpperHops, |
| source.getRowLowerEqualsUpper(), source.getColLowerEqualsUpper()); |
| |
| indexOp.setAllPositions(indexOp.getBeginLine(), indexOp.getBeginColumn(), indexOp.getEndLine(), indexOp.getEndColumn()); |
| setIdentifierParams(indexOp, target); |
| |
| return indexOp; |
| } |
| |
| |
| /** |
| * Construct Hops from parse tree : Process Binary Expression in an |
| * assignment statement |
| * |
| * @throws ParseException |
| */ |
| private Hop processBinaryExpression(BinaryExpression source, DataIdentifier target, HashMap<String, Hop> hops) |
| throws ParseException |
| { |
| Hop left = processExpression(source.getLeft(), null, hops); |
| Hop right = processExpression(source.getRight(), null, hops); |
| |
| if (left == null || right == null){ |
| left = processExpression(source.getLeft(), null, hops); |
| right = processExpression(source.getRight(), null, hops); |
| } |
| |
| Hop currBop = null; |
| |
| //prepare target identifier and ensure that output type is of inferred type |
| //(type should not be determined by target (e.g., string for print) |
| if (target == null) { |
| target = createTarget(source); |
| } |
| target.setValueType(source.getOutput().getValueType()); |
| |
| if (source.getOpCode() == Expression.BinaryOp.PLUS) { |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.PLUS, left, right); |
| } else if (source.getOpCode() == Expression.BinaryOp.MINUS) { |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MINUS, left, right); |
| } else if (source.getOpCode() == Expression.BinaryOp.MULT) { |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, left, right); |
| } else if (source.getOpCode() == Expression.BinaryOp.DIV) { |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.DIV, left, right); |
| } else if (source.getOpCode() == Expression.BinaryOp.MODULUS) { |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MODULUS, left, right); |
| } else if (source.getOpCode() == Expression.BinaryOp.INTDIV) { |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.INTDIV, left, right); |
| } else if (source.getOpCode() == Expression.BinaryOp.MATMULT) { |
| currBop = new AggBinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, AggOp.SUM, left, right); |
| } else if (source.getOpCode() == Expression.BinaryOp.POW) { |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.POW, left, right); |
| } |
| else { |
| throw new ParseException("Unsupported parsing of binary expression: "+source.getOpCode()); |
| } |
| setIdentifierParams(currBop, source.getOutput()); |
| currBop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| return currBop; |
| |
| } |
| |
| private Hop processRelationalExpression(RelationalExpression source, DataIdentifier target, |
| HashMap<String, Hop> hops) throws ParseException { |
| |
| Hop left = processExpression(source.getLeft(), null, hops); |
| Hop right = processExpression(source.getRight(), null, hops); |
| |
| Hop currBop = null; |
| |
| if (target == null) { |
| target = createTarget(source); |
| if(left.getDataType() == DataType.MATRIX || right.getDataType() == DataType.MATRIX) { |
| // Added to support matrix relational comparison |
| target.setDataType(DataType.MATRIX); |
| target.setValueType(ValueType.BOOLEAN); |
| } |
| else { |
| // Added to support scalar relational comparison |
| target.setDataType(DataType.SCALAR); |
| target.setValueType(ValueType.BOOLEAN); |
| } |
| } |
| |
| OpOp2 op = null; |
| |
| if (source.getOpCode() == Expression.RelationalOp.LESS) { |
| op = OpOp2.LESS; |
| } else if (source.getOpCode() == Expression.RelationalOp.LESSEQUAL) { |
| op = OpOp2.LESSEQUAL; |
| } else if (source.getOpCode() == Expression.RelationalOp.GREATER) { |
| op = OpOp2.GREATER; |
| } else if (source.getOpCode() == Expression.RelationalOp.GREATEREQUAL) { |
| op = OpOp2.GREATEREQUAL; |
| } else if (source.getOpCode() == Expression.RelationalOp.EQUAL) { |
| op = OpOp2.EQUAL; |
| } else if (source.getOpCode() == Expression.RelationalOp.NOTEQUAL) { |
| op = OpOp2.NOTEQUAL; |
| } |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right); |
| currBop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| return currBop; |
| } |
| |
| /** |
| * |
| * @param source |
| * @param target |
| * @param hops |
| * @return |
| * @throws ParseException |
| */ |
| private Hop processBooleanExpression(BooleanExpression source, DataIdentifier target, HashMap<String, Hop> hops) |
| throws ParseException |
| { |
| // Boolean Not has a single parameter |
| boolean constLeft = (source.getLeft().getOutput() instanceof ConstIdentifier); |
| boolean constRight = false; |
| if (source.getRight() != null) { |
| constRight = (source.getRight().getOutput() instanceof ConstIdentifier); |
| } |
| |
| if (constLeft || constRight) { |
| LOG.error(source.printErrorLocation() + "Boolean expression with constant unsupported"); |
| throw new RuntimeException(source.printErrorLocation() + "Boolean expression with constant unsupported"); |
| } |
| |
| Hop left = processExpression(source.getLeft(), null, hops); |
| Hop right = null; |
| if (source.getRight() != null) { |
| right = processExpression(source.getRight(), null, hops); |
| } |
| |
| //prepare target identifier and ensure that output type is boolean |
| //(type should not be determined by target (e.g., string for print) |
| if (target == null) { |
| target = createTarget(source); |
| } |
| target.setValueType(ValueType.BOOLEAN); |
| |
| if (source.getRight() == null) { |
| Hop currUop = null; |
| try { |
| currUop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NOT, left); |
| currUop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| } catch (HopsException e) { |
| throw new ParseException(e.getMessage()); |
| } |
| return currUop; |
| } else { |
| Hop currBop = null; |
| OpOp2 op = null; |
| |
| if (source.getOpCode() == Expression.BooleanOp.LOGICALAND) { |
| op = OpOp2.AND; |
| } else if (source.getOpCode() == Expression.BooleanOp.LOGICALOR) { |
| op = OpOp2.OR; |
| } else { |
| LOG.error(source.printErrorLocation() + "Unknown boolean operation " + source.getOpCode()); |
| throw new RuntimeException(source.printErrorLocation() + "Unknown boolean operation " + source.getOpCode()); |
| } |
| currBop = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, left, right); |
| currBop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| // setIdentifierParams(currBop,source.getOutput()); |
| return currBop; |
| } |
| } |
| |
| private Hop constructDfHop(String name, DataType dt, ValueType vt, ParameterizedBuiltinFunctionOp op, HashMap<String,Hop> paramHops) throws HopsException { |
| |
| // Add a hop to paramHops to store distribution information. |
| // Distribution parameter hops would have been already present in paramHops. |
| Hop distLop = null; |
| switch(op) { |
| case QNORM: |
| case PNORM: |
| distLop = new LiteralOp("normal"); |
| break; |
| case QT: |
| case PT: |
| distLop = new LiteralOp("t"); |
| break; |
| case QF: |
| case PF: |
| distLop = new LiteralOp("f"); |
| break; |
| case QCHISQ: |
| case PCHISQ: |
| distLop = new LiteralOp("chisq"); |
| break; |
| case QEXP: |
| case PEXP: |
| distLop = new LiteralOp("exp"); |
| break; |
| |
| case CDF: |
| case INVCDF: |
| break; |
| |
| default: |
| throw new HopsException("Invalid operation: " + op); |
| } |
| if (distLop != null) |
| paramHops.put("dist", distLop); |
| |
| return new ParameterizedBuiltinOp(name, dt, vt, ParameterizedBuiltinFunctionExpression.pbHopMap.get(op), paramHops); |
| } |
| |
| /** |
| * |
| * @param source |
| * @param targetList |
| * @param hops |
| * @return |
| * @throws ParseException |
| */ |
| private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, ArrayList<DataIdentifier> targetList, |
| HashMap<String, Hop> hops) throws ParseException |
| { |
| FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN; |
| String nameSpace = DMLProgram.INTERNAL_NAMESPACE; |
| |
| // Create an array list to hold the outputs of this lop. |
| // Exact list of outputs are added based on opcode. |
| ArrayList<Hop> outputs = new ArrayList<Hop>(); |
| |
| // Construct Hop for current builtin function expression based on its type |
| Hop currBuiltinOp = null; |
| switch (source.getOpCode()) { |
| case TRANSFORMENCODE: |
| ArrayList<Hop> inputs = new ArrayList<Hop>(); |
| inputs.add( processExpression(source.getVarParam("target"), null, hops) ); |
| inputs.add( processExpression(source.getVarParam("spec"), null, hops) ); |
| String[] outputNames = new String[targetList.size()]; |
| outputNames[0] = ((DataIdentifier)targetList.get(0)).getName(); |
| outputNames[1] = ((DataIdentifier)targetList.get(1)).getName(); |
| outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[0])); |
| outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[1])); |
| |
| currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs); |
| break; |
| |
| default: |
| throw new ParseException("Invaid Opcode in DMLTranslator:processMultipleReturnParameterizedBuiltinFunctionExpression(): " + source.getOpCode()); |
| } |
| |
| // set properties for created hops based on outputs of source expression |
| for ( int i=0; i < source.getOutputs().length; i++ ) { |
| setIdentifierParams( outputs.get(i), source.getOutputs()[i]); |
| outputs.get(i).setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| } |
| currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| |
| return currBuiltinOp; |
| } |
| |
| /** |
| * Construct Hops from parse tree : Process ParameterizedBuiltinFunction Expression in an |
| * assignment statement |
| * |
| * @throws ParseException |
| * @throws HopsException |
| */ |
| private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, DataIdentifier target, |
| HashMap<String, Hop> hops) throws ParseException, HopsException { |
| |
| // this expression has multiple "named" parameters |
| HashMap<String, Hop> paramHops = new HashMap<String,Hop>(); |
| |
| // -- construct hops for all input parameters |
| // -- store them in hashmap so that their "name"s are maintained |
| Hop pHop = null; |
| for ( String paramName : source.getVarParams().keySet() ) { |
| pHop = processExpression(source.getVarParam(paramName), null, hops); |
| paramHops.put(paramName, pHop); |
| } |
| |
| Hop currBuiltinOp = null; |
| |
| if (target == null) { |
| target = createTarget(source); |
| } |
| |
| // construct hop based on opcode |
| switch(source.getOpCode()) { |
| case CDF: |
| case INVCDF: |
| case QNORM: |
| case QT: |
| case QF: |
| case QCHISQ: |
| case QEXP: |
| case PNORM: |
| case PT: |
| case PF: |
| case PCHISQ: |
| case PEXP: |
| currBuiltinOp = constructDfHop(target.getName(), target.getDataType(), target.getValueType(), source.getOpCode(), paramHops); |
| break; |
| |
| case GROUPEDAGG: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.GROUPEDAGG, paramHops); |
| break; |
| |
| case RMEMPTY: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.RMEMPTY, paramHops); |
| break; |
| |
| case REPLACE: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.REPLACE, paramHops); |
| break; |
| |
| case ORDER: |
| ArrayList<Hop> inputs = new ArrayList<Hop>(); |
| inputs.add(paramHops.get("target")); |
| inputs.add(paramHops.get("by")); |
| inputs.add(paramHops.get("decreasing")); |
| inputs.add(paramHops.get("index.return")); |
| |
| currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.SORT, inputs); |
| |
| break; |
| |
| case TRANSFORM: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), |
| target.getValueType(), ParamBuiltinOp.TRANSFORM, |
| paramHops); |
| break; |
| |
| case TRANSFORMAPPLY: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), |
| target.getValueType(), ParamBuiltinOp.TRANSFORMAPPLY, |
| paramHops); |
| break; |
| |
| case TRANSFORMDECODE: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), |
| target.getValueType(), ParamBuiltinOp.TRANSFORMDECODE, |
| paramHops); |
| break; |
| |
| case TRANSFORMMETA: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), |
| target.getValueType(), ParamBuiltinOp.TRANSFORMMETA, |
| paramHops); |
| break; |
| |
| case TOSTRING: |
| currBuiltinOp = new ParameterizedBuiltinOp( |
| target.getName(), target.getDataType(), |
| target.getValueType(), ParamBuiltinOp.TOSTRING, |
| paramHops); |
| break; |
| |
| default: |
| |
| LOG.error(source.printErrorLocation() + |
| "processParameterizedBuiltinFunctionExpression() -- Unknown operation: " |
| + source.getOpCode()); |
| |
| throw new ParseException(source.printErrorLocation() + |
| "processParameterizedBuiltinFunctionExpression() -- Unknown operation: " |
| + source.getOpCode()); |
| } |
| |
| setIdentifierParams(currBuiltinOp, source.getOutput()); |
| |
| currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| |
| return currBuiltinOp; |
| } |
| |
| /** |
| * Construct Hops from parse tree : Process ParameterizedExpression in a |
| * read/write/rand statement |
| * |
| * @throws ParseException |
| * @throws HopsException |
| */ |
| private Hop processDataExpression(DataExpression source, DataIdentifier target, |
| HashMap<String, Hop> hops) throws ParseException, HopsException { |
| |
| // this expression has multiple "named" parameters |
| HashMap<String, Hop> paramHops = new HashMap<String,Hop>(); |
| |
| // -- construct hops for all input parameters |
| // -- store them in hashmap so that their "name"s are maintained |
| Hop pHop = null; |
| for ( String paramName : source.getVarParams().keySet() ) { |
| pHop = processExpression(source.getVarParam(paramName), null, hops); |
| paramHops.put(paramName, pHop); |
| } |
| |
| Hop currBuiltinOp = null; |
| |
| if (target == null) { |
| target = createTarget(source); |
| } |
| |
| // construct hop based on opcode |
| switch(source.getOpCode()) { |
| case READ: |
| currBuiltinOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), DataOpTypes.PERSISTENTREAD, paramHops); |
| ((DataOp)currBuiltinOp).setFileName(((StringIdentifier)source.getVarParam(DataExpression.IO_FILENAME)).getValue()); |
| break; |
| |
| case WRITE: |
| String name = target.getName(); |
| currBuiltinOp = new DataOp( |
| target.getName(), target.getDataType(), target.getValueType(), DataOpTypes.PERSISTENTWRITE, hops.get(name), paramHops); |
| //MB: commented for dynamic write |
| /*Identifier ioFilename = (Identifier)source.getVarParam(DataExpression.IO_FILENAME); |
| if (!(ioFilename instanceof StringIdentifier)) { |
| LOG.error(source.printErrorLocation() + "processDataExpression():: Filename must be a constant string value"); |
| throw new ParseException(source.printErrorLocation() + "processDataExpression():: Filename must be a constant string value"); |
| } else { |
| ((DataOp)currBuiltinOp).setFileName(((StringIdentifier)ioFilename).getValue()); |
| }*/ |
| break; |
| |
| case RAND: |
| // We limit RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, and RAND_PDF to be constants |
| DataGenMethod method = (paramHops.get(DataExpression.RAND_MIN).getValueType()==ValueType.STRING) ? |
| DataGenMethod.SINIT : DataGenMethod.RAND; |
| currBuiltinOp = new DataGenOp(method, target, paramHops); |
| break; |
| |
| case MATRIX: |
| ArrayList<Hop> tmp = new ArrayList<Hop>(); |
| tmp.add( 0, paramHops.get(DataExpression.RAND_DATA) ); |
| tmp.add( 1, paramHops.get(DataExpression.RAND_ROWS) ); |
| tmp.add( 2, paramHops.get(DataExpression.RAND_COLS) ); |
| tmp.add( 3, paramHops.get(DataExpression.RAND_BY_ROW) ); |
| currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.RESHAPE, tmp); |
| break; |
| |
| |
| default: |
| LOG.error(source.printErrorLocation() + |
| "processDataExpression():: Unknown operation: " |
| + source.getOpCode()); |
| |
| throw new ParseException(source.printErrorLocation() + |
| "processDataExpression():: Unknown operation: " |
| + source.getOpCode()); |
| } |
| |
| //set identifier meta data (incl dimensions and blocksizes) |
| setIdentifierParams(currBuiltinOp, source.getOutput()); |
| if( source.getOpCode()==DataExpression.DataOp.READ ) |
| ((DataOp)currBuiltinOp).setInputBlockSizes(target.getRowsInBlock(), target.getColumnsInBlock()); |
| currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| |
| return currBuiltinOp; |
| } |
| |
| /** |
| * Construct HOps from parse tree: process BuiltinFunction Expressions in |
| * MultiAssignment Statements. For all other builtin function expressions, |
| * <code>processBuiltinFunctionExpression()</code> is used. |
| */ |
| private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpression source, ArrayList<DataIdentifier> targetList, |
| HashMap<String, Hop> hops) throws ParseException { |
| |
| // Construct Hops for all inputs |
| ArrayList<Hop> inputs = new ArrayList<Hop>(); |
| inputs.add( processExpression(source.getFirstExpr(), null, hops) ); |
| if ( source.getSecondExpr() != null ) |
| inputs.add( processExpression(source.getSecondExpr(), null, hops) ); |
| if ( source.getThirdExpr() != null ) |
| inputs.add( processExpression(source.getThirdExpr(), null, hops) ); |
| |
| FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN; |
| String nameSpace = DMLProgram.INTERNAL_NAMESPACE; |
| |
| // Create an array list to hold the outputs of this lop. |
| // Exact list of outputs are added based on opcode. |
| ArrayList<Hop> outputs = new ArrayList<Hop>(); |
| |
| // Construct Hop for current builtin function expression based on its type |
| Hop currBuiltinOp = null; |
| switch (source.getOpCode()) { |
| case QR: |
| case LU: |
| case EIGEN: |
| |
| // Number of outputs = size of targetList = #of identifiers in source.getOutputs |
| String[] outputNames = new String[targetList.size()]; |
| for ( int i=0; i < targetList.size(); i++ ) { |
| outputNames[i] = ((DataIdentifier)targetList.get(i)).getName(); |
| Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.DOUBLE, inputs.get(0), DataOpTypes.FUNCTIONOUTPUT, outputNames[i]); |
| outputs.add(output); |
| } |
| |
| // Create the hop for current function call |
| FunctionOp fcall = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), inputs, outputNames, outputs); |
| currBuiltinOp = fcall; |
| |
| break; |
| |
| default: |
| throw new ParseException("Invaid Opcode in DMLTranslator:processMultipleReturnBuiltinFunctionExpression(): " + source.getOpCode()); |
| } |
| |
| // set properties for created hops based on outputs of source expression |
| for ( int i=0; i < source.getOutputs().length; i++ ) { |
| setIdentifierParams( outputs.get(i), source.getOutputs()[i]); |
| outputs.get(i).setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| } |
| currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| |
| return currBuiltinOp; |
| } |
| |
| /** |
| * Construct Hops from parse tree : Process BuiltinFunction Expression in an |
| * assignment statement |
| * |
| * @throws ParseException |
| * @throws HopsException |
| */ |
| private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target, |
| HashMap<String, Hop> hops) throws ParseException, HopsException { |
| Hop expr = processExpression(source.getFirstExpr(), null, hops); |
| Hop expr2 = null; |
| if (source.getSecondExpr() != null) { |
| expr2 = processExpression(source.getSecondExpr(), null, hops); |
| } |
| Hop expr3 = null; |
| if (source.getThirdExpr() != null) { |
| expr3 = processExpression(source.getThirdExpr(), null, hops); |
| } |
| |
| Hop currBuiltinOp = null; |
| |
| if (target == null) { |
| target = createTarget(source); |
| } |
| |
| // Construct the hop based on the type of Builtin function |
| switch (source.getOpCode()) { |
| |
| case COLSUM: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, |
| Direction.Col, expr); |
| break; |
| |
| case COLMAX: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, |
| Direction.Col, expr); |
| break; |
| |
| case COLMIN: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, |
| Direction.Col, expr); |
| break; |
| |
| case COLMEAN: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, |
| Direction.Col, expr); |
| break; |
| |
| case COLSD: |
| // colStdDevs = sqrt(colVariances) |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), AggOp.VAR, Direction.Col, expr); |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp); |
| break; |
| |
| case COLVAR: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), AggOp.VAR, Direction.Col, expr); |
| break; |
| |
| case ROWSUM: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, |
| Direction.Row, expr); |
| break; |
| |
| case ROWMAX: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, |
| Direction.Row, expr); |
| break; |
| |
| case ROWINDEXMAX: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAXINDEX, |
| Direction.Row, expr); |
| break; |
| |
| case ROWINDEXMIN: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MININDEX, |
| Direction.Row, expr); |
| break; |
| |
| case ROWMIN: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, |
| Direction.Row, expr); |
| break; |
| |
| case ROWMEAN: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, |
| Direction.Row, expr); |
| break; |
| |
| case ROWSD: |
| // rowStdDevs = sqrt(rowVariances) |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), AggOp.VAR, Direction.Row, expr); |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp); |
| break; |
| |
| case ROWVAR: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), AggOp.VAR, Direction.Row, expr); |
| break; |
| |
| case NROW: |
| // If the dimensions are available at compile time, then create a LiteralOp (constant propagation) |
| // Else create a UnaryOp so that a control program instruction is generated |
| |
| long nRows = expr.getDim1(); |
| if (nRows == -1) { |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NROW, expr); |
| } |
| else { |
| currBuiltinOp = new LiteralOp(nRows); |
| } |
| break; |
| |
| case NCOL: |
| // If the dimensions are available at compile time, then create a LiteralOp (constant propagation) |
| // Else create a UnaryOp so that a control program instruction is generated |
| |
| long nCols = expr.getDim2(); |
| if (nCols == -1) { |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NCOL, expr); |
| } |
| else { |
| currBuiltinOp = new LiteralOp(nCols); |
| } |
| break; |
| case LENGTH: |
| long nRows2 = expr.getDim1(); |
| long nCols2 = expr.getDim2(); |
| /* |
| * If the dimensions are available at compile time, then create a LiteralOp (constant propagation) |
| * Else create a UnaryOp so that a control program instruction is generated |
| */ |
| if ((nCols2 == -1) || (nRows2 == -1)) { |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.LENGTH, expr); |
| } |
| else { |
| long lval = (nCols2 * nRows2); |
| currBuiltinOp = new LiteralOp(lval); |
| } |
| break; |
| |
| case SUM: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, |
| Direction.RowCol, expr); |
| break; |
| |
| case MEAN: |
| if ( expr2 == null ) { |
| // example: x = mean(Y); |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, |
| Direction.RowCol, expr); |
| } |
| else { |
| // example: x = mean(Y,W); |
| // stable weighted mean is implemented by using centralMoment with order = 0 |
| Hop orderHop = new LiteralOp(0); |
| currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp3.CENTRALMOMENT, expr, expr2, orderHop); |
| } |
| break; |
| |
| case SD: |
| // stdDev = sqrt(variance) |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), AggOp.VAR, Direction.RowCol, expr); |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp); |
| break; |
| |
| case VAR: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), |
| target.getValueType(), AggOp.VAR, Direction.RowCol, expr); |
| break; |
| |
| case MIN: |
| //construct AggUnary for min(X) but BinaryOp for min(X,Y) |
| if( expr2 == null ) { |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| AggOp.MIN, Direction.RowCol, expr); |
| } |
| else { |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MIN, |
| expr, expr2); |
| } |
| break; |
| |
| case MAX: |
| //construct AggUnary for max(X) but BinaryOp for max(X,Y) |
| if( expr2 == null ) { |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| AggOp.MAX, Direction.RowCol, expr); |
| } else { |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MAX, |
| expr, expr2); |
| } |
| break; |
| |
| case PPRED: |
| String sop = ((StringIdentifier)source.getThirdExpr()).getValue(); |
| sop = sop.replace("\"", ""); |
| OpOp2 operation; |
| if ( sop.equalsIgnoreCase(">=") ) |
| operation = OpOp2.GREATEREQUAL; |
| else if ( sop.equalsIgnoreCase(">") ) |
| operation = OpOp2.GREATER; |
| else if ( sop.equalsIgnoreCase("<=") ) |
| operation = OpOp2.LESSEQUAL; |
| else if ( sop.equalsIgnoreCase("<") ) |
| operation = OpOp2.LESS; |
| else if ( sop.equalsIgnoreCase("==") ) |
| operation = OpOp2.EQUAL; |
| else if ( sop.equalsIgnoreCase("!=") ) |
| operation = OpOp2.NOTEQUAL; |
| else { |
| LOG.error(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED."); |
| throw new ParseException(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED."); |
| } |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2); |
| break; |
| |
| case PROD: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.PROD, |
| Direction.RowCol, expr); |
| break; |
| case TRACE: |
| currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.TRACE, |
| Direction.RowCol, expr); |
| break; |
| |
| case TRANS: |
| currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.ReOrgOp.TRANSPOSE, expr); |
| break; |
| |
| case REV: |
| currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.ReOrgOp.REV, expr); |
| break; |
| |
| case CBIND: |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.CBIND, expr, expr2); |
| break; |
| |
| case RBIND: |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.RBIND, expr, expr2); |
| break; |
| |
| case DIAG: |
| currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.ReOrgOp.DIAG, expr); |
| break; |
| |
| case TABLE: |
| |
| // Always a TertiaryOp is created for table(). |
| // - create a hop for weights, if not provided in the function call. |
| int numTableArgs = source._args.length; |
| |
| switch(numTableArgs) { |
| case 2: |
| case 4: |
| // example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15) |
| // here, weight is interpreted as 1.0 |
| Hop weightHop = new LiteralOp(1.0); |
| // set dimensions |
| weightHop.setDim1(0); |
| weightHop.setDim2(0); |
| weightHop.setNnz(-1); |
| weightHop.setRowsInBlock(0); |
| weightHop.setColsInBlock(0); |
| |
| if ( numTableArgs == 2 ) |
| currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop); |
| else { |
| Hop outDim1 = processExpression(source._args[2], null, hops); |
| Hop outDim2 = processExpression(source._args[3], null, hops); |
| currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2); |
| } |
| break; |
| |
| case 3: |
| case 5: |
| // example DML statement: F = ctable(A,B,W) or F = ctable(A,B,W,10,15) |
| if (numTableArgs == 3) |
| currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3); |
| else { |
| Hop outDim1 = processExpression(source._args[3], null, hops); |
| Hop outDim2 = processExpression(source._args[4], null, hops); |
| currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2); |
| } |
| break; |
| |
| default: |
| throw new ParseException("Invalid number of arguments "+ numTableArgs + " to table() function."); |
| } |
| break; |
| |
| //data type casts |
| case CAST_AS_SCALAR: |
| currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr); |
| break; |
| case CAST_AS_MATRIX: |
| currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr); |
| break; |
| case CAST_AS_FRAME: |
| currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr); |
| break; |
| |
| //value type casts |
| case CAST_AS_DOUBLE: |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.DOUBLE, Hop.OpOp1.CAST_AS_DOUBLE, expr); |
| break; |
| case CAST_AS_INT: |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT, Hop.OpOp1.CAST_AS_INT, expr); |
| break; |
| case CAST_AS_BOOLEAN: |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr); |
| break; |
| case ABS: |
| case SIN: |
| case COS: |
| case TAN: |
| case ASIN: |
| case ACOS: |
| case ATAN: |
| case SIGN: |
| case SQRT: |
| case EXP: |
| case ROUND: |
| case CEIL: |
| case FLOOR: |
| case CUMSUM: |
| case CUMPROD: |
| case CUMMIN: |
| case CUMMAX: |
| Hop.OpOp1 mathOp1; |
| switch (source.getOpCode()) { |
| case ABS: |
| mathOp1 = Hop.OpOp1.ABS; |
| break; |
| case SIN: |
| mathOp1 = Hop.OpOp1.SIN; |
| break; |
| case COS: |
| mathOp1 = Hop.OpOp1.COS; |
| break; |
| case TAN: |
| mathOp1 = Hop.OpOp1.TAN; |
| break; |
| case ASIN: |
| mathOp1 = Hop.OpOp1.ASIN; |
| break; |
| case ACOS: |
| mathOp1 = Hop.OpOp1.ACOS; |
| break; |
| case ATAN: |
| mathOp1 = Hop.OpOp1.ATAN; |
| break; |
| case SIGN: |
| mathOp1 = Hop.OpOp1.SIGN; |
| break; |
| case SQRT: |
| mathOp1 = Hop.OpOp1.SQRT; |
| break; |
| case EXP: |
| mathOp1 = Hop.OpOp1.EXP; |
| break; |
| case ROUND: |
| mathOp1 = Hop.OpOp1.ROUND; |
| break; |
| case CEIL: |
| mathOp1 = Hop.OpOp1.CEIL; |
| break; |
| case FLOOR: |
| mathOp1 = Hop.OpOp1.FLOOR; |
| break; |
| case CUMSUM: |
| mathOp1 = Hop.OpOp1.CUMSUM; |
| break; |
| case CUMPROD: |
| mathOp1 = Hop.OpOp1.CUMPROD; |
| break; |
| case CUMMIN: |
| mathOp1 = Hop.OpOp1.CUMMIN; |
| break; |
| case CUMMAX: |
| mathOp1 = Hop.OpOp1.CUMMAX; |
| break; |
| default: |
| |
| LOG.error(source.printErrorLocation() + |
| "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " |
| + source.getOpCode()); |
| |
| throw new ParseException(source.printErrorLocation() + |
| "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " |
| + source.getOpCode()); |
| } |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp1, expr); |
| break; |
| case LOG: |
| if (expr2 == null) { |
| Hop.OpOp1 mathOp2; |
| switch (source.getOpCode()) { |
| case LOG: |
| mathOp2 = Hop.OpOp1.LOG; |
| break; |
| default: |
| |
| LOG.error(source.printErrorLocation() + |
| "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " |
| + source.getOpCode()); |
| |
| throw new ParseException(source.printErrorLocation() + |
| "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " |
| + source.getOpCode()); |
| } |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp2, |
| expr); |
| } else { |
| Hop.OpOp2 mathOp3; |
| switch (source.getOpCode()) { |
| case LOG: |
| mathOp3 = Hop.OpOp2.LOG; |
| break; |
| default: |
| |
| LOG.error(source.printErrorLocation() + |
| "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " |
| + source.getOpCode()); |
| |
| throw new ParseException(source.printErrorLocation() + |
| "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " |
| + source.getOpCode()); |
| } |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp3, |
| expr, expr2); |
| } |
| break; |
| case MOMENT: |
| if (expr3 == null){ |
| currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.CENTRALMOMENT, expr, expr2); |
| } |
| else { |
| currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp3.CENTRALMOMENT, expr, expr2,expr3); |
| } |
| break; |
| |
| case COV: |
| if (expr3 == null){ |
| currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.COVARIANCE, expr, expr2); |
| } |
| else { |
| currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp3.COVARIANCE, expr, expr2,expr3); |
| } |
| break; |
| |
| case QUANTILE: |
| if (expr3 == null){ |
| currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.QUANTILE, expr, expr2); |
| } |
| else { |
| currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp3.QUANTILE, expr, expr2,expr3); |
| } |
| break; |
| |
| case INTERQUANTILE: |
| if ( expr3 == null ) { |
| currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.INTERQUANTILE, expr, expr2); |
| } |
| else { |
| currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp3.INTERQUANTILE, expr, expr2,expr3); |
| } |
| break; |
| |
| case IQM: |
| if ( expr2 == null ) { |
| currBuiltinOp=new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp1.IQM, expr); |
| } |
| else { |
| currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.IQM, expr, expr2); |
| } |
| break; |
| |
| case MEDIAN: |
| if ( expr2 == null ) { |
| currBuiltinOp=new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp1.MEDIAN, expr); |
| } |
| else { |
| currBuiltinOp=new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp2.MEDIAN, expr, expr2); |
| } |
| break; |
| |
| case SEQ: |
| HashMap<String,Hop> randParams = new HashMap<String,Hop>(); |
| randParams.put(Statement.SEQ_FROM, expr); |
| randParams.put(Statement.SEQ_TO, expr2); |
| randParams.put(Statement.SEQ_INCR, (expr3!=null)?expr3 : new LiteralOp(1)); |
| //note incr: default -1 (for from>to) handled during runtime |
| currBuiltinOp = new DataGenOp(DataGenMethod.SEQ, target, randParams); |
| break; |
| |
| case SAMPLE: |
| { |
| Expression[] in = source.getAllExpr(); |
| |
| // arguments: range/size/replace/seed; defaults: replace=FALSE |
| |
| HashMap<String,Hop> tmpparams = new HashMap<String, Hop>(); |
| tmpparams.put(DataExpression.RAND_MAX, expr); //range |
| tmpparams.put(DataExpression.RAND_ROWS, expr2); |
| tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1)); |
| |
| if ( in.length == 4 ) |
| { |
| tmpparams.put(DataExpression.RAND_PDF, expr3); |
| Hop seed = processExpression(in[3], null, hops); |
| tmpparams.put(DataExpression.RAND_SEED, seed); |
| } |
| else if ( in.length == 3 ) |
| { |
| // check if the third argument is "replace" or "seed" |
| if ( expr3.getValueType() == ValueType.BOOLEAN ) |
| { |
| tmpparams.put(DataExpression.RAND_PDF, expr3); |
| tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); |
| } |
| else if ( expr3.getValueType() == ValueType.INT ) |
| { |
| tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false)); |
| tmpparams.put(DataExpression.RAND_SEED, expr3 ); |
| } |
| else |
| throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample()."); |
| |
| } |
| else if ( in.length == 2 ) |
| { |
| tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false)); |
| tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); |
| } |
| |
| currBuiltinOp = new DataGenOp(DataGenMethod.SAMPLE, target, tmpparams); |
| break; |
| } |
| |
| case SOLVE: |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.SOLVE, expr, expr2); |
| break; |
| |
| case INVERSE: |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp1.INVERSE, expr); |
| break; |
| |
| case CHOLESKY: |
| currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), |
| Hop.OpOp1.CHOLESKY, expr); |
| break; |
| |
| case OUTER: |
| if( !(expr3 instanceof LiteralOp) ) |
| throw new HopsException("Operator for outer builtin function must be a constant: "+expr3); |
| OpOp2 op = Hop.getOpOp2ForOuterVectorOperation(((LiteralOp)expr3).getStringValue()); |
| if( op == null ) |
| throw new HopsException("Unsupported outer vector binary operation: "+((LiteralOp)expr3).getStringValue()); |
| |
| currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, expr, expr2); |
| ((BinaryOp)currBuiltinOp).setOuterVectorOperation(true); //flag op as specific outer vector operation |
| currBuiltinOp.refreshSizeInformation(); //force size reevaluation according to 'outer' flag otherwise danger of incorrect dims |
| break; |
| |
| case CONV2D: |
| { |
| Hop image = expr; |
| ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops); |
| currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D, inHops1); |
| setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); |
| break; |
| } |
| case AVG_POOL: |
| case MAX_POOL: |
| { |
| Hop image = expr; |
| ArrayList<Hop> inHops1 = getALHopsForPoolingForwardIM2COL(image, source, 1, hops); |
| if(source.getOpCode() == BuiltinFunctionOp.MAX_POOL) |
| currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING, inHops1); |
| else |
| throw new HopsException("Average pooling is not implemented"); |
| setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); |
| break; |
| } |
| case MAX_POOL_BACKWARD: |
| { |
| Hop image = expr; |
| ArrayList<Hop> inHops1 = getALHopsForConvOpPoolingCOL2IM(image, source, 1, hops); // process dout as well |
| currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1); |
| setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); |
| break; |
| } |
| case CONV2D_BACKWARD_FILTER: |
| { |
| Hop image = expr; |
| ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops); |
| currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, inHops1); |
| setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); |
| break; |
| } |
| case CONV2D_BACKWARD_DATA: |
| { |
| Hop image = expr; |
| ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops); |
| currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA, inHops1); |
| setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); |
| break; |
| } |
| |
| default: |
| throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); |
| } |
| |
| setIdentifierParams(currBuiltinOp, source.getOutput()); |
| currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); |
| return currBuiltinOp; |
| } |
| |
| private void setBlockSizeAndRefreshSizeInfo(Hop in, Hop out) { |
| HopRewriteUtils.setOutputBlocksizes(out, in.getRowsInBlock(), in.getColsInBlock()); |
| HopRewriteUtils.copyLineNumbers(in, out); |
| out.refreshSizeInformation(); |
| } |
| |
| private ArrayList<Hop> getALHopsForConvOpPoolingCOL2IM(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { |
| ArrayList<Hop> ret = new ArrayList<Hop>(); |
| ret.add(first); |
| Expression[] allExpr = source.getAllExpr(); |
| |
| for(int i = skip; i < allExpr.length; i++) { |
| if(i == 11) { |
| ret.add(processExpression(allExpr[7], null, hops)); // Make number of channels of images and filter the same |
| } |
| else |
| ret.add(processExpression(allExpr[i], null, hops)); |
| } |
| return ret; |
| } |
| |
| private ArrayList<Hop> getALHopsForPoolingForwardIM2COL(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { |
| ArrayList<Hop> ret = new ArrayList<Hop>(); |
| ret.add(first); |
| Expression[] allExpr = source.getAllExpr(); |
| if(skip != 1) { |
| throw new ParseException("Unsupported skip"); |
| } |
| |
| Expression numChannels = allExpr[6]; |
| |
| for(int i = skip; i < allExpr.length; i++) { |
| if(i == 10) { |
| ret.add(processExpression(numChannels, null, hops)); |
| } |
| else |
| ret.add(processExpression(allExpr[i], null, hops)); |
| } |
| return ret; |
| } |
| |
| @SuppressWarnings("unused") //TODO remove if not used |
| private ArrayList<Hop> getALHopsForConvOpPoolingIM2COL(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { |
| ArrayList<Hop> ret = new ArrayList<Hop>(); |
| ret.add(first); |
| Expression[] allExpr = source.getAllExpr(); |
| int numImgIndex = -1; |
| if(skip == 1) { |
| numImgIndex = 5; |
| } |
| else if(skip == 2) { |
| numImgIndex = 6; |
| } |
| else { |
| throw new ParseException("Unsupported skip"); |
| } |
| |
| for(int i = skip; i < allExpr.length; i++) { |
| if(i == numImgIndex) { // skip=1 ==> i==5 and skip=2 => i==6 |
| Expression numImg = allExpr[numImgIndex]; |
| Expression numChannels = allExpr[numImgIndex+1]; |
| BinaryExpression tmp = new BinaryExpression(org.apache.sysml.parser.Expression.BinaryOp.MULT, |
| numImg.getFilename(), numImg.getBeginLine(), numImg.getBeginColumn(), numImg.getEndLine(), numImg.getEndColumn()); |
| tmp.setLeft(numImg); |
| tmp.setRight(numChannels); |
| ret.add(processTempIntExpression(tmp, hops)); |
| ret.add(processExpression(new IntIdentifier(1, numImg.getFilename(), numImg.getBeginLine(), numImg.getBeginColumn(), |
| numImg.getEndLine(), numImg.getEndColumn()), null, hops)); |
| i++; |
| } |
| else |
| ret.add(processExpression(allExpr[i], null, hops)); |
| } |
| return ret; |
| } |
| |
| private ArrayList<Hop> getALHopsForConvOp(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { |
| ArrayList<Hop> ret = new ArrayList<Hop>(); |
| ret.add(first); |
| Expression[] allExpr = source.getAllExpr(); |
| for(int i = skip; i < allExpr.length; i++) { |
| ret.add(processExpression(allExpr[i], null, hops)); |
| } |
| return ret; |
| } |
| |
| public void setIdentifierParams(Hop h, Identifier id) { |
| if( id.getDim1()>= 0 ) |
| h.setDim1(id.getDim1()); |
| if( id.getDim2()>= 0 ) |
| h.setDim2(id.getDim2()); |
| if( id.getNnz()>= 0 ) |
| h.setNnz(id.getNnz()); |
| h.setRowsInBlock(id.getRowsInBlock()); |
| h.setColsInBlock(id.getColumnsInBlock()); |
| } |
| |
| public void setIdentifierParams(Hop h, Hop source) { |
| |
| h.setDim1(source.getDim1()); |
| h.setDim2(source.getDim2()); |
| h.setNnz(source.getNnz()); |
| h.setRowsInBlock(source.getRowsInBlock()); |
| h.setColsInBlock(source.getColsInBlock()); |
| } |
| |
| /** |
| * |
| * @param prog |
| * @param pWrites |
| * @throws LanguageException |
| */ |
| private boolean prepareReadAfterWrite( DMLProgram prog, HashMap<String, DataIdentifier> pWrites ) |
| throws LanguageException |
| { |
| boolean ret = false; |
| |
| //process functions |
| /*MB: for the moment we only support read-after-write in the main program |
| for( FunctionStatementBlock fsb : prog.getFunctionStatementBlocks() ) |
| ret |= prepareReadAfterWrite(fsb, pWrites); |
| */ |
| |
| //process main program |
| for( StatementBlock sb : prog.getStatementBlocks() ) |
| ret |= prepareReadAfterWrite(sb, pWrites); |
| |
| return ret; |
| } |
| |
| /** |
| * |
| * @param sb |
| * @param pWrites |
| */ |
| private boolean prepareReadAfterWrite( StatementBlock sb, HashMap<String, DataIdentifier> pWrites ) |
| { |
| boolean ret = false; |
| |
| if(sb instanceof FunctionStatementBlock) |
| { |
| FunctionStatementBlock fsb = (FunctionStatementBlock) sb; |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| for (StatementBlock csb : fstmt.getBody()) |
| ret |= prepareReadAfterWrite(csb, pWrites); |
| } |
| else if(sb instanceof WhileStatementBlock) |
| { |
| WhileStatementBlock wsb = (WhileStatementBlock) sb; |
| WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); |
| for (StatementBlock csb : wstmt.getBody()) |
| ret |= prepareReadAfterWrite(csb, pWrites); |
| } |
| else if(sb instanceof IfStatementBlock) |
| { |
| IfStatementBlock isb = (IfStatementBlock) sb; |
| IfStatement istmt = (IfStatement)isb.getStatement(0); |
| for (StatementBlock csb : istmt.getIfBody()) |
| ret |= prepareReadAfterWrite(csb, pWrites); |
| for (StatementBlock csb : istmt.getElseBody()) |
| ret |= prepareReadAfterWrite(csb, pWrites); |
| } |
| else if(sb instanceof ForStatementBlock) //incl parfor |
| { |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fstmt = (ForStatement)fsb.getStatement(0); |
| for (StatementBlock csb : fstmt.getBody()) |
| ret |= prepareReadAfterWrite(csb, pWrites); |
| } |
| else //generic (last-level) |
| { |
| for( Statement s : sb.getStatements() ) |
| { |
| //collect persistent write information |
| if( s instanceof OutputStatement ) |
| { |
| OutputStatement os = (OutputStatement) s; |
| String pfname = os.getExprParam(DataExpression.IO_FILENAME).toString(); |
| DataIdentifier di = (DataIdentifier) os.getSource().getOutput(); |
| pWrites.put(pfname, di); |
| } |
| //propagate size information into reads-after-write |
| else if( s instanceof AssignmentStatement |
| && ((AssignmentStatement)s).getSource() instanceof DataExpression ) |
| { |
| DataExpression dexpr = (DataExpression) ((AssignmentStatement)s).getSource(); |
| if( dexpr.isRead() ){ |
| String pfname = dexpr.getVarParam(DataExpression.IO_FILENAME).toString(); |
| if( pWrites.containsKey(pfname) && !pfname.trim().isEmpty() ) //found read-after-write |
| { |
| //update read with essential write meta data |
| DataIdentifier di = pWrites.get(pfname); |
| FormatType ft = (di.getFormatType()!=null) ? di.getFormatType() : FormatType.TEXT; |
| dexpr.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(ft.toString(),di.getFilename(),di.getBeginLine(),di.getBeginColumn(),di.getEndLine(),di.getEndColumn())); |
| if( di.getDim1()>=0 ) |
| dexpr.addVarParam(DataExpression.READROWPARAM, new IntIdentifier(di.getDim1(),di.getFilename(),di.getBeginLine(),di.getBeginColumn(),di.getEndLine(),di.getEndColumn())); |
| if( di.getDim2()>=0 ) |
| dexpr.addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(di.getDim2(),di.getFilename(),di.getBeginLine(),di.getBeginColumn(),di.getEndLine(),di.getEndColumn())); |
| if( di.getValueType()!=ValueType.UNKNOWN ) |
| dexpr.addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier(di.getValueType().toString(),di.getFilename(),di.getBeginLine(),di.getBeginColumn(),di.getEndLine(),di.getEndColumn())); |
| if( di.getDataType()!=DataType.UNKNOWN ) |
| dexpr.addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier(di.getDataType().toString(),di.getFilename(),di.getBeginLine(),di.getBeginColumn(),di.getEndLine(),di.getEndColumn())); |
| ret = true; |
| } |
| } |
| } |
| } |
| } |
| |
| return ret; |
| } |
| } |