blob: 878a5fc437ea31806e166119bd692ed35fc6b8af [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.api.jmlc;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
/**
* Utility class containing static methods for working with JMLC.
*
*/
public class JMLCUtils
{
/**
* Removes rmvar instructions that would remove any of the given outputs.
* This is important for keeping registered outputs after the program terminates.
*
* @param prog the DML/PyDML program
* @param outputs registered output variables
*/
public static void cleanupRuntimeProgram( Program prog, String[] outputs) {
Map<String, FunctionProgramBlock> funcMap = prog.getFunctionProgramBlocks();
HashSet<String> excludeList = new HashSet<>(Arrays.asList(outputs));
if( funcMap != null && !funcMap.isEmpty() ) {
for( Entry<String, FunctionProgramBlock> e : funcMap.entrySet() ) {
FunctionProgramBlock fpb = e.getValue();
for( ProgramBlock pb : fpb.getChildBlocks() )
rCleanupRuntimeProgram(pb, excludeList);
}
}
for( ProgramBlock pb : prog.getProgramBlocks() )
rCleanupRuntimeProgram(pb, excludeList);
}
/**
* Cleanup program blocks (called recursively).
*
* @param pb program block
* @param outputs registered output variables
*/
public static void rCleanupRuntimeProgram( ProgramBlock pb, HashSet<String> outputs ) {
if( pb instanceof WhileProgramBlock ) {
WhileProgramBlock wpb = (WhileProgramBlock)pb;
for( ProgramBlock pbc : wpb.getChildBlocks() )
rCleanupRuntimeProgram(pbc,outputs);
}
else if( pb instanceof IfProgramBlock ) {
IfProgramBlock ipb = (IfProgramBlock)pb;
for( ProgramBlock pbc : ipb.getChildBlocksIfBody() )
rCleanupRuntimeProgram(pbc,outputs);
for( ProgramBlock pbc : ipb.getChildBlocksElseBody() )
rCleanupRuntimeProgram(pbc,outputs);
}
else if( pb instanceof ForProgramBlock ) {
ForProgramBlock fpb = (ForProgramBlock)pb;
for( ProgramBlock pbc : fpb.getChildBlocks() )
rCleanupRuntimeProgram(pbc,outputs);
}
else if( pb instanceof BasicProgramBlock ) {
BasicProgramBlock bpb = (BasicProgramBlock) pb;
bpb.setInstructions(cleanupRuntimeInstructions(
bpb.getInstructions(), outputs));
}
}
/**
* Cleanup runtime instructions, removing rmvar instructions for
* any of the given output variable names.
*
* @param insts list of instructions
* @param outputs registered output variables
* @return list of instructions
*/
public static ArrayList<Instruction> cleanupRuntimeInstructions( ArrayList<Instruction> insts, String... outputs ) {
return cleanupRuntimeInstructions(insts, new HashSet<>(Arrays.asList(outputs)));
}
/**
* Cleanup runtime instructions, removing rmvar instructions for
* any of the given output variable names.
*
* @param insts list of instructions
* @param outputs registered output variables
* @return list of instructions
*/
public static ArrayList<Instruction> cleanupRuntimeInstructions( ArrayList<Instruction> insts, HashSet<String> outputs ) {
ArrayList<Instruction> ret = new ArrayList<>();
for( Instruction inst : insts ) {
if( inst instanceof VariableCPInstruction && ((VariableCPInstruction)inst).isRemoveVariable() ) {
ArrayList<String> currRmVar = new ArrayList<>();
for( CPOperand input : ((VariableCPInstruction)inst).getInputs() )
if( !outputs.contains(input.getName()) )
currRmVar.add(input.getName());
if( !currRmVar.isEmpty() ) {
ret.add(VariableCPInstruction.prepareRemoveInstruction(
currRmVar.toArray(new String[0])));
}
}
else
ret.add(inst);
}
return ret;
}
}