blob: 679253809969e495f55c8b287e9f1afa1d9d42a2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.controlprogram;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.FunctionDictionary;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
/**
 * Runtime program: an ordered list of top-level {@link ProgramBlock}s plus a
 * registry of {@link FunctionProgramBlock}s grouped by namespace.
 *
 * <p>Each function may be registered in two variants selected by a boolean
 * {@code opt} flag (presumably optimized vs. unoptimized/original — see
 * {@code FunctionDictionary}). A {@code null} namespace is always mapped to
 * {@link DMLProgram#DEFAULT_NAMESPACE}.
 *
 * <p>Thread-safety: all function-registry methods are {@code synchronized}
 * on this instance; the list of top-level program blocks is not guarded.
 */
public class Program
{
	/** Key delimiter constant ("::"); NOTE(review): presumably matches DMLProgram's function-key format. */
	public static final String KEY_DELIM = "::";

	/** Associated DML program; {@code null} when created via {@link #Program()}. */
	private DMLProgram _prog;
	/** Top-level program blocks, executed in order by {@link #execute}. */
	private final ArrayList<ProgramBlock> _programBlocks;
	/** Function dictionaries by namespace; never contains null values. */
	private final HashMap<String, FunctionDictionary<FunctionProgramBlock>> _namespaces;

	/** Creates an empty program with only the (empty) default namespace registered. */
	public Program() {
		_namespaces = new HashMap<>();
		_namespaces.put(DMLProgram.DEFAULT_NAMESPACE, new FunctionDictionary<>());
		_programBlocks = new ArrayList<>();
	}

	/**
	 * Creates an empty program associated with the given DML program.
	 *
	 * @param prog the DML program (may be null)
	 */
	public Program(DMLProgram prog) {
		this();
		setDMLProg(prog);
	}

	/** @param prog the DML program to associate with this runtime program */
	public void setDMLProg(DMLProgram prog) {
		_prog = prog;
	}

	/** @return the associated DML program, or null if none was set */
	public DMLProgram getDMLProg() {
		return _prog;
	}

	/**
	 * Registers a function program block with {@code opt=true}.
	 *
	 * @param namespace namespace, or null for the default namespace
	 * @param fname function name
	 * @param fpb function program block (must not be null)
	 * @throws DMLRuntimeException if {@code fpb} is null
	 */
	public synchronized void addFunctionProgramBlock(String namespace, String fname, FunctionProgramBlock fpb) {
		addFunctionProgramBlock(namespace, fname, fpb, true);
	}

	/**
	 * Registers a function program block, creating the namespace on demand.
	 *
	 * @param namespace namespace, or null for the default namespace
	 * @param fname function name
	 * @param fpb function program block (must not be null)
	 * @param opt which variant of the function to register
	 * @throws DMLRuntimeException if {@code fpb} is null
	 */
	public synchronized void addFunctionProgramBlock(String namespace, String fname, FunctionProgramBlock fpb, boolean opt) {
		if( fpb == null )
			throw new DMLRuntimeException("Invalid null function program block.");
		namespace = getSafeNamespace(namespace);
		FunctionDictionary<FunctionProgramBlock> dict = _namespaces.get(namespace);
		if( dict == null )
			_namespaces.put(namespace, dict = new FunctionDictionary<>());
		dict.addFunction(fname, fpb, opt);
	}

	/**
	 * Removes a function if the namespace exists and
	 * {@code FunctionDictionary.containsFunction(fname)} reports it; otherwise a no-op.
	 *
	 * @param namespace namespace, or null for the default namespace
	 * @param fname function name
	 */
	public synchronized void removeFunctionProgramBlock(String namespace, String fname) {
		FunctionDictionary<?> dict = _namespaces.get(getSafeNamespace(namespace));
		if( dict != null && dict.containsFunction(fname) )
			dict.removeFunction(fname);
	}

	/** @return all {@code opt=true} functions, keyed by {@code DMLProgram.constructFunctionKey} */
	public HashMap<String,FunctionProgramBlock> getFunctionProgramBlocks(){
		return getFunctionProgramBlocks(true);
	}

	/**
	 * Collects the functions of the requested variant across all namespaces.
	 *
	 * @param opt which variant to collect
	 * @return a new map from {@code DMLProgram.constructFunctionKey(namespace, name)}
	 *         to function program block
	 */
	public synchronized HashMap<String,FunctionProgramBlock> getFunctionProgramBlocks(boolean opt){
		HashMap<String,FunctionProgramBlock> retVal = new HashMap<>();
		for( Entry<String,FunctionDictionary<FunctionProgramBlock>> ns : _namespaces.entrySet() ) {
			Map<String,FunctionProgramBlock> funcs = ns.getValue().getFunctions(opt);
			if( funcs == null ) //no functions of this variant in this namespace
				continue;
			for( Entry<String,FunctionProgramBlock> e : funcs.entrySet() )
				retVal.put(DMLProgram.constructFunctionKey(ns.getKey(), e.getKey()), e.getValue());
		}
		return retVal;
	}

	/**
	 * @param namespace namespace, or null for the default namespace
	 * @param fname function name
	 * @return true if the namespace exists and {@code containsFunction(fname)} holds
	 */
	public synchronized boolean containsFunctionProgramBlock(String namespace, String fname) {
		FunctionDictionary<FunctionProgramBlock> dict = _namespaces.get(getSafeNamespace(namespace));
		return dict != null && dict.containsFunction(fname);
	}

	/**
	 * @param fkey function key, split via {@code DMLProgram.splitFunctionKey}
	 *        into namespace (part 0) and function name (part 1)
	 * @param opt which variant to look up
	 * @return true if the function variant is registered
	 */
	public synchronized boolean containsFunctionProgramBlock(String fkey, boolean opt) {
		String[] parts = DMLProgram.splitFunctionKey(fkey);
		return containsFunctionProgramBlock(parts[0], parts[1], opt);
	}

	/**
	 * @param namespace namespace, or null for the default namespace
	 * @param fname function name
	 * @param opt which variant to look up
	 * @return true if the namespace exists and contains the function variant
	 */
	public synchronized boolean containsFunctionProgramBlock(String namespace, String fname, boolean opt) {
		FunctionDictionary<FunctionProgramBlock> dict = _namespaces.get(getSafeNamespace(namespace));
		return dict != null && dict.containsFunction(fname, opt);
	}

	/**
	 * Returns the {@code opt=true} variant of a function.
	 *
	 * @throws DMLRuntimeException if the namespace or function is undefined
	 */
	public synchronized FunctionProgramBlock getFunctionProgramBlock(String namespace, String fname) {
		return getFunctionProgramBlock(namespace, fname, true);
	}

	/**
	 * Returns a function variant by combined function key.
	 *
	 * @param fkey function key, split via {@code DMLProgram.splitFunctionKey}
	 * @param opt which variant to return
	 * @throws DMLRuntimeException if the namespace or function is undefined
	 */
	public synchronized FunctionProgramBlock getFunctionProgramBlock(String fkey, boolean opt) {
		String[] parts = DMLProgram.splitFunctionKey(fkey);
		return getFunctionProgramBlock(parts[0], parts[1], opt);
	}

	/**
	 * Returns a registered function variant; never returns null.
	 *
	 * @param namespace namespace, or null for the default namespace
	 * @param fname function name
	 * @param opt which variant to return
	 * @return the function program block
	 * @throws DMLRuntimeException if the namespace is undefined or the lookup yields null
	 */
	public synchronized FunctionProgramBlock getFunctionProgramBlock(String namespace, String fname, boolean opt) {
		namespace = getSafeNamespace(namespace);
		FunctionDictionary<FunctionProgramBlock> dict = _namespaces.get(namespace);
		if( dict == null )
			throw new DMLRuntimeException("namespace " + namespace + " is undefined.");
		FunctionProgramBlock retVal = dict.getFunction(fname, opt);
		if( retVal == null )
			throw new DMLRuntimeException("function " + fname + " ("+opt+") is undefined in namespace " + namespace);
		return retVal;
	}

	/** Appends a top-level program block (not synchronized). */
	public void addProgramBlock(ProgramBlock pb) {
		_programBlocks.add(pb);
	}

	/** @return the live, mutable list of top-level program blocks */
	public ArrayList<ProgramBlock> getProgramBlocks() {
		return _programBlocks;
	}

	/**
	 * Executes all top-level program blocks in order.
	 *
	 * @param ec execution context passed to each block
	 * @throws DMLScriptException rethrown unchanged
	 * @throws DMLRuntimeException wrapping any other exception
	 */
	public void execute(ExecutionContext ec) {
		try{
			//index loop kept (no iterator) so the list is re-read on each step
			for (int i=0; i<_programBlocks.size(); i++)
				_programBlocks.get(i).execute(ec);
		}
		catch(DMLScriptException e) {
			throw e;
		}
		catch(Exception e) {
			throw new DMLRuntimeException(e);
		}
	}

	/**
	 * Creates a copy of this program. Only shallow copies are supported: the new
	 * program shares the DML program, the top-level program blocks, and all
	 * registered function program blocks (both {@code opt} variants).
	 *
	 * @param deep must be false
	 * @return the shallow copy
	 * @throws NotImplementedException if {@code deep} is true
	 */
	public synchronized Program clone(boolean deep) {
		if( deep )
			throw new NotImplementedException();
		Program ret = new Program(_prog);
		//shallow copy of all program blocks
		ret._programBlocks.addAll(_programBlocks);
		//shallow copy of all functions, for both opt variants, so the clone
		//can answer the same opt=true/false lookups as the original
		for( Entry<String, FunctionDictionary<FunctionProgramBlock>> ns : _namespaces.entrySet() ) {
			for( boolean opt : new boolean[] {true, false} ) {
				Map<String, FunctionProgramBlock> funcs = ns.getValue().getFunctions(opt);
				if( funcs == null ) //variant absent in this namespace
					continue;
				for( Entry<String, FunctionProgramBlock> e : funcs.entrySet() )
					ret.addFunctionProgramBlock(ns.getKey(), e.getKey(), e.getValue(), opt);
			}
		}
		return ret;
	}

	/**
	 * Deep clone; not supported.
	 *
	 * @throws NotImplementedException always, since this delegates to {@code clone(true)}
	 */
	@Override
	public Object clone() {
		return clone(true);
	}

	/** Maps a null namespace to {@link DMLProgram#DEFAULT_NAMESPACE}. */
	private static String getSafeNamespace(String namespace) {
		return (namespace == null) ? DMLProgram.DEFAULT_NAMESPACE : namespace;
	}
}