| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.yarn.ropt; |
| |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.HashSet; |
| |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.hops.OptimizerUtils; |
| import org.apache.sysml.parser.StatementBlock; |
| import org.apache.sysml.runtime.DMLRuntimeException; |
| import org.apache.sysml.runtime.controlprogram.ForProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.IfProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.ProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; |
| |
| public class GridEnumerationMemory extends GridEnumeration |
| { |
| |
| public static final int DEFAULT_NSTEPS = 20; //old config: 15 |
| public static final int DEFAULT_MEM_ADD = 1*1024*1024; //1MB |
| |
| private int _nsteps = -1; |
| |
| public GridEnumerationMemory( ArrayList<ProgramBlock> prog, long min, long max ) |
| throws DMLRuntimeException |
| { |
| super(prog, min, max); |
| |
| _nsteps = DEFAULT_NSTEPS; |
| } |
| |
| /** |
| * |
| * @param steps |
| */ |
| public void setNumSteps( int steps ) |
| { |
| _nsteps = steps; |
| } |
| |
| @Override |
| public ArrayList<Long> enumerateGridPoints() |
| throws DMLRuntimeException, HopsException |
| { |
| ArrayList<Long> ret = new ArrayList<Long>(); |
| long gap = (long)(_max - _min) / (_nsteps-1); |
| |
| //get memory estimates |
| ArrayList<Long> mem = new ArrayList<Long>(); |
| getMemoryEstimates( _prog, mem ); |
| |
| //binning memory estimates to equi grid |
| HashSet<Long> preRet = new HashSet<Long>(); |
| for( Long val : mem ) |
| { |
| if( val < _min ) |
| preRet.add( _min ); //only right side |
| else if( val > _max ) |
| preRet.add( _max ); //only left side |
| else |
| { |
| long bin = Math.max((val-_min)/gap,0); |
| preRet.add( filterMax(_min + bin*gap) ); |
| preRet.add( filterMax(_min + (bin+1)*gap) ); |
| } |
| } |
| |
| //create sorted output (to prevent over-provisioning) |
| for( Long val : preRet ) |
| ret.add(val); |
| Collections.sort(ret); //asc |
| |
| return ret; |
| } |
| |
| private long filterMax( long val ) |
| { |
| if( val > _max ) //truncate max |
| return _max; |
| |
| return val; |
| } |
| |
| /** |
| * |
| * @param pbs |
| * @param mem |
| * @throws HopsException |
| */ |
| private void getMemoryEstimates( ArrayList<ProgramBlock> pbs, ArrayList<Long> mem ) |
| throws HopsException |
| { |
| for( ProgramBlock pb : pbs ) |
| getMemoryEstimates(pb, mem); |
| } |
| |
| /** |
| * |
| * @param pb |
| * @param mem |
| * @throws HopsException |
| */ |
| private void getMemoryEstimates( ProgramBlock pb, ArrayList<Long> mem ) |
| throws HopsException |
| { |
| if (pb instanceof FunctionProgramBlock) |
| { |
| FunctionProgramBlock fpb = (FunctionProgramBlock)pb; |
| getMemoryEstimates(fpb.getChildBlocks(), mem); |
| } |
| else if (pb instanceof WhileProgramBlock) |
| { |
| WhileProgramBlock fpb = (WhileProgramBlock)pb; |
| getMemoryEstimates(fpb.getChildBlocks(), mem); |
| } |
| else if (pb instanceof IfProgramBlock) |
| { |
| IfProgramBlock fpb = (IfProgramBlock)pb; |
| getMemoryEstimates(fpb.getChildBlocksIfBody(), mem); |
| getMemoryEstimates(fpb.getChildBlocksElseBody(), mem); |
| } |
| else if (pb instanceof ForProgramBlock) //incl parfor |
| { |
| ForProgramBlock fpb = (ForProgramBlock)pb; |
| getMemoryEstimates(fpb.getChildBlocks(), mem); |
| } |
| else |
| { |
| StatementBlock sb = pb.getStatementBlock(); |
| if( sb != null && sb.get_hops() != null ){ |
| Hop.resetVisitStatus(sb.get_hops()); |
| for( Hop hop : sb.get_hops() ) |
| getMemoryEstimates(hop, mem); |
| } |
| } |
| } |
| |
| /** |
| * |
| * @param hop |
| * @param mem |
| */ |
| private void getMemoryEstimates( Hop hop, ArrayList<Long> mem ) |
| { |
| if( hop.getVisited() == Hop.VisitStatus.DONE ) |
| return; |
| |
| //process childs |
| for(Hop hi : hop.getInput()) |
| getMemoryEstimates(hi, mem); |
| |
| //add memory estimates (scaled by CP memory ratio) |
| mem.add( (long)( (hop.getMemEstimate()+DEFAULT_MEM_ADD) |
| /OptimizerUtils.MEM_UTIL_FACTOR) ); |
| |
| hop.setVisited(Hop.VisitStatus.DONE); |
| } |
| |
| } |