| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.api.mlcontext; |
| |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.LinkedHashMap; |
| import java.util.LinkedHashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Map.Entry; |
| |
| import org.apache.sysds.runtime.controlprogram.LocalVariableMap; |
| import org.apache.sysds.runtime.controlprogram.caching.CacheableData; |
| import org.apache.sysds.runtime.instructions.cp.Data; |
| |
| import java.util.Set; |
| |
| import scala.Tuple2; |
| import scala.Tuple3; |
| import scala.collection.JavaConversions; |
| |
| /** |
| * A Script object encapsulates a DML or PYDML script. |
| * |
| */ |
| public class Script { |
| /** |
| * The script content. |
| */ |
| private String scriptString; |
| /** |
| * The optional name of the script. |
| */ |
| private String name; |
| /** |
| * All inputs (input parameters ($) and input variables). |
| */ |
| private Map<String, Object> inputs = new LinkedHashMap<>(); |
| /** |
| * The input parameters ($). |
| */ |
| private Map<String, Object> inputParameters = new LinkedHashMap<>(); |
| /** |
| * The input variables. |
| */ |
| private Set<String> inputVariables = new LinkedHashSet<>(); |
| /** |
| * The input matrix or frame metadata if present. |
| */ |
| private Map<String, Metadata> inputMetadata = new LinkedHashMap<>(); |
| /** |
| * The output variables. |
| */ |
| private Set<String> outputVariables = new LinkedHashSet<>(); |
| /** |
| * The symbol table containing the data associated with variables. |
| */ |
| private LocalVariableMap symbolTable = new LocalVariableMap(); |
| /** |
| * The ScriptExecutor which is used to define the execution of the script. |
| */ |
| private ScriptExecutor scriptExecutor; |
| /** |
| * The results of the execution of the script. |
| */ |
| private MLResults results; |
| |
| /** |
| * Script constructor, which by default creates a DML script. |
| */ |
| public Script() { |
| |
| } |
| |
| /** |
| * Script constructor, specifying the script content. By default, the script |
| * type is DML. |
| * |
| * @param scriptString |
| * the script content as a string |
| */ |
| public Script(String scriptString) { |
| this.scriptString = scriptString; |
| } |
| |
| /** |
| * Obtain the script string. |
| * |
| * @return the script string |
| */ |
| public String getScriptString() { |
| return scriptString; |
| } |
| |
| /** |
| * Set the script string. |
| * |
| * @param scriptString |
| * the script string |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script setScriptString(String scriptString) { |
| this.scriptString = scriptString; |
| return this; |
| } |
| |
| /** |
| * Obtain the input variable names as an unmodifiable set of strings. |
| * |
| * @return the input variable names |
| */ |
| public Set<String> getInputVariables() { |
| return Collections.unmodifiableSet(inputVariables); |
| } |
| |
| /** |
| * Obtain the output variable names as an unmodifiable set of strings. |
| * |
| * @return the output variable names |
| */ |
| public Set<String> getOutputVariables() { |
| return Collections.unmodifiableSet(outputVariables); |
| } |
| |
| /** |
| * Obtain the symbol table, which is essentially a |
| * {@code HashMap<String, Data>} representing variables and their values. |
| * |
| * @return the symbol table |
| */ |
| public LocalVariableMap getSymbolTable() { |
| return symbolTable; |
| } |
| |
| /** |
| * Obtain an unmodifiable map of all inputs (parameters ($) and variables). |
| * |
| * @return all inputs to the script |
| */ |
| public Map<String, Object> getInputs() { |
| return Collections.unmodifiableMap(inputs); |
| } |
| |
| /** |
| * Obtain an unmodifiable map of input matrix/frame metadata. |
| * |
| * @return input matrix/frame metadata |
| */ |
| public Map<String, Metadata> getInputMetadata() { |
| return Collections.unmodifiableMap(inputMetadata); |
| } |
| |
| /** |
| * Pass a map of inputs to the script. |
| * |
| * @param inputs |
| * map of inputs (parameters ($) and variables). |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script in(Map<String, Object> inputs) { |
| for (Entry<String, Object> input : inputs.entrySet()) { |
| in(input.getKey(), input.getValue()); |
| } |
| |
| return this; |
| } |
| |
| /** |
| * Pass a Scala Map of inputs to the script. |
| * <p> |
| * Note that the {@code Map} value type is not explicitly specified on this |
| * method because {@code [String, Any]} can't be recognized on the Java side |
| * since {@code Any} doesn't have an equivalent in the Java class hierarchy |
| * ({@code scala.Any} is a superclass of {@code scala.AnyRef}, which is |
| * equivalent to {@code java.lang.Object}). Therefore, specifying |
| * {@code scala.collection.Map<String, Object>} as an input parameter to |
| * this Java method is not encompassing enough and would require types such |
| * as a {@code scala.Double} to be cast using {@code asInstanceOf[AnyRef]}. |
| * |
| * @param inputs |
| * Scala Map of inputs (parameters ($) and variables). |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| @SuppressWarnings({ "rawtypes", "unchecked" }) |
| public Script in(scala.collection.Map<String, ?> inputs) { |
| Map javaMap = JavaConversions.mapAsJavaMap(inputs); |
| in(javaMap); |
| |
| return this; |
| } |
| |
| /** |
| * Pass a Scala Seq of inputs to the script. The inputs are either two-value |
| * or three-value tuples, where the first value is the variable name, the |
| * second value is the variable value, and the third optional value is the |
| * metadata. |
| * |
| * @param inputs |
| * Scala Seq of inputs (parameters ($) and variables). |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script in(scala.collection.Seq<Object> inputs) { |
| List<Object> list = JavaConversions.seqAsJavaList(inputs); |
| for (Object obj : list) { |
| if (obj instanceof Tuple3) { |
| @SuppressWarnings("unchecked") |
| Tuple3<String, Object, MatrixMetadata> t3 = (Tuple3<String, Object, MatrixMetadata>) obj; |
| in(t3._1(), t3._2(), t3._3()); |
| } else if (obj instanceof Tuple2) { |
| @SuppressWarnings("unchecked") |
| Tuple2<String, Object> t2 = (Tuple2<String, Object>) obj; |
| in(t2._1(), t2._2()); |
| } else { |
| throw new MLContextException("Only Tuples of 2 or 3 values are permitted"); |
| } |
| } |
| return this; |
| } |
| |
| /** |
| * Obtain an unmodifiable map of all input parameters ($). |
| * |
| * @return input parameters ($) |
| */ |
| public Map<String, Object> getInputParameters() { |
| return inputParameters; |
| } |
| |
| /** |
| * Register an input (parameter ($) or variable). |
| * |
| * @param name |
| * name of the input |
| * @param value |
| * value of the input |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script in(String name, Object value) { |
| return in(name, value, null); |
| } |
| |
| /** |
| * Register an input (parameter ($) or variable) with optional matrix |
| * metadata. |
| * |
| * @param name |
| * name of the input |
| * @param value |
| * value of the input |
| * @param metadata |
| * optional matrix/frame metadata |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script in(String name, Object value, Metadata metadata) { |
| |
| if ((value != null) && (value instanceof Long)) { |
| // convert Long to Integer since Long not a supported value type |
| Long lng = (Long) value; |
| value = lng.intValue(); |
| } else if ((value != null) && (value instanceof Float)) { |
| // convert Float to Double since Float not a supported value type |
| Float flt = (Float) value; |
| value = flt.doubleValue(); |
| } |
| |
| MLContextUtil.checkInputValueType(name, value); |
| if (inputs == null) { |
| inputs = new LinkedHashMap<>(); |
| } |
| inputs.put(name, value); |
| |
| if (name.startsWith("$")) { |
| MLContextUtil.checkInputParameterType(name, value); |
| if (inputParameters == null) { |
| inputParameters = new LinkedHashMap<>(); |
| } |
| inputParameters.put(name, value); |
| } else { |
| Data data = MLContextUtil.convertInputType(name, value, metadata); |
| if (data != null) { |
| // store input variable name and data |
| symbolTable.put(name, data); |
| inputVariables.add(name); |
| |
| // store matrix/frame meta data and disable variable cleanup |
| if (data instanceof CacheableData) { |
| if (metadata != null) |
| inputMetadata.put(name, metadata); |
| ((CacheableData<?>) data).enableCleanup(false); |
| } |
| } |
| } |
| return this; |
| } |
| |
| /** |
| * Register an output variable. |
| * |
| * @param outputName |
| * name of the output variable |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script out(String outputName) { |
| outputVariables.add(outputName); |
| return this; |
| } |
| |
| /** |
| * Register output variables. |
| * |
| * @param outputNames |
| * names of the output variables |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script out(String... outputNames) { |
| outputVariables.addAll(Arrays.asList(outputNames)); |
| return this; |
| } |
| |
| /** |
| * Register output variables. |
| * |
| * @param outputNames |
| * names of the output variables |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script out(scala.collection.Seq<String> outputNames) { |
| List<String> list = JavaConversions.seqAsJavaList(outputNames); |
| outputVariables.addAll(list); |
| return this; |
| } |
| |
| /** |
| * Register output variables. |
| * |
| * @param outputNames |
| * names of the output variables |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script out(List<String> outputNames) { |
| outputVariables.addAll(outputNames); |
| return this; |
| } |
| |
| /** |
| * Clear the inputs, outputs, and symbol table. |
| */ |
| public void clearIOS() { |
| clearInputs(); |
| clearOutputs(); |
| clearSymbolTable(); |
| } |
| |
| /** |
| * Clear the inputs and outputs, but not the symbol table. |
| */ |
| public void clearIO() { |
| clearInputs(); |
| clearOutputs(); |
| } |
| |
| /** |
| * Clear the script string, inputs, outputs, and symbol table. |
| */ |
| public void clearAll() { |
| scriptString = null; |
| clearIOS(); |
| } |
| |
| /** |
| * Clear the inputs. |
| */ |
| public void clearInputs() { |
| inputs.clear(); |
| inputParameters.clear(); |
| inputVariables.clear(); |
| inputMetadata.clear(); |
| } |
| |
| /** |
| * Clear the outputs. |
| */ |
| public void clearOutputs() { |
| outputVariables.clear(); |
| } |
| |
| /** |
| * Clear the symbol table. |
| */ |
| public void clearSymbolTable() { |
| symbolTable.removeAll(); |
| } |
| |
| /** |
| * Obtain the results of the script execution. |
| * |
| * @return the results of the script execution. |
| */ |
| public MLResults results() { |
| return results; |
| } |
| |
| /** |
| * Obtain the results of the script execution. |
| * |
| * @return the results of the script execution. |
| */ |
| public MLResults getResults() { |
| return results; |
| } |
| |
| /** |
| * Set the results of the script execution. |
| * |
| * @param results |
| * the results of the script execution. |
| */ |
| public void setResults(MLResults results) { |
| this.results = results; |
| } |
| |
| /** |
| * Obtain the script executor used by this Script. |
| * |
| * @return the ScriptExecutor used by this Script. |
| */ |
| public ScriptExecutor getScriptExecutor() { |
| return scriptExecutor; |
| } |
| |
| /** |
| * Set the ScriptExecutor used by this Script. |
| * |
| * @param scriptExecutor |
| * the script executor |
| */ |
| public void setScriptExecutor(ScriptExecutor scriptExecutor) { |
| this.scriptExecutor = scriptExecutor; |
| } |
| |
| /** |
| * Generate the script execution string, which adds read/load/write/save |
| * statements to the beginning and end of the script to execute. |
| * |
| * @return the script execution string |
| */ |
| public String getScriptExecutionString() { |
| StringBuilder sb = new StringBuilder(); |
| |
| Set<String> ins = getInputVariables(); |
| for (String in : ins) { |
| Object inValue = getInputs().get(in); |
| sb.append(in); |
| if (inValue instanceof String) { |
| String quotedString = MLContextUtil.quotedString((String) inValue); |
| sb.append(" = " + quotedString + ";\n"); |
| } else if (MLContextUtil.isBasicType(inValue)) { |
| sb.append(" = read('', data_type='scalar', value_type='" + MLContextUtil.getBasicTypeString(inValue) |
| + "');\n"); |
| } else if (MLContextUtil.doesSymbolTableContainFrameObject(symbolTable, in)) { |
| sb.append(" = read('', data_type='frame');\n"); |
| } else { |
| sb.append(" = read('');\n"); |
| } |
| } |
| |
| sb.append(getScriptString()); |
| if (!getScriptString().endsWith("\n")) { |
| sb.append("\n"); |
| } |
| |
| Set<String> outs = getOutputVariables(); |
| for (String out : outs) { |
| sb.append("write("); |
| sb.append(out); |
| sb.append(", '');\n"); |
| } |
| |
| return sb.toString(); |
| } |
| |
| @Override |
| public String toString() { |
| StringBuilder sb = new StringBuilder(); |
| |
| sb.append(MLContextUtil.displayInputs("Inputs", inputs, symbolTable)); |
| sb.append("\n"); |
| sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable)); |
| return sb.toString(); |
| } |
| |
| /** |
| * Display information about the script as a String. This consists of the |
| * script type, inputs, outputs, input parameters, input variables, output |
| * variables, the symbol table, the script string, and the script execution |
| * string. |
| * |
| * @return information about this script as a String |
| */ |
| public String info() { |
| StringBuilder sb = new StringBuilder(); |
| |
| sb.append(MLContextUtil.displayInputs("Inputs", inputs, symbolTable)); |
| sb.append("\n"); |
| sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable)); |
| sb.append("\n"); |
| sb.append(MLContextUtil.displayMap("Input Parameters", inputParameters)); |
| sb.append("\n"); |
| sb.append(MLContextUtil.displaySet("Input Variables", inputVariables)); |
| sb.append("\n"); |
| sb.append(MLContextUtil.displaySet("Output Variables", outputVariables)); |
| sb.append("\n"); |
| sb.append(MLContextUtil.displaySymbolTable("Symbol Table", symbolTable)); |
| sb.append("\nScript String:\n"); |
| sb.append(scriptString); |
| sb.append("\nScript Execution String:\n"); |
| sb.append(getScriptExecutionString()); |
| sb.append("\n"); |
| |
| return sb.toString(); |
| } |
| |
| /** |
| * Display the script inputs. |
| * |
| * @return the script inputs |
| */ |
| public String displayInputs() { |
| return MLContextUtil.displayInputs("Inputs", inputs, symbolTable); |
| } |
| |
| /** |
| * Display the script outputs. |
| * |
| * @return the script outputs as a String |
| */ |
| public String displayOutputs() { |
| return MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable); |
| } |
| |
| /** |
| * Display the script input parameters. |
| * |
| * @return the script input parameters as a String |
| */ |
| public String displayInputParameters() { |
| return MLContextUtil.displayMap("Input Parameters", inputParameters); |
| } |
| |
| /** |
| * Display the script input variables. |
| * |
| * @return the script input variables as a String |
| */ |
| public String displayInputVariables() { |
| return MLContextUtil.displaySet("Input Variables", inputVariables); |
| } |
| |
| /** |
| * Display the script output variables. |
| * |
| * @return the script output variables as a String |
| */ |
| public String displayOutputVariables() { |
| return MLContextUtil.displaySet("Output Variables", outputVariables); |
| } |
| |
| /** |
| * Display the script symbol table. |
| * |
| * @return the script symbol table as a String |
| */ |
| public String displaySymbolTable() { |
| return MLContextUtil.displaySymbolTable("Symbol Table", symbolTable); |
| } |
| |
| /** |
| * Obtain the script name. |
| * |
| * @return the script name |
| */ |
| public String getName() { |
| return name; |
| } |
| |
| /** |
| * Set the script name. |
| * |
| * @param name |
| * the script name |
| * @return {@code this} Script object to allow chaining of methods |
| */ |
| public Script setName(String name) { |
| this.name = name; |
| return this; |
| } |
| |
| /** |
| * Execute the script and return the results as an MLResults object. |
| * |
| * @return results as an MLResults object |
| */ |
| public MLResults execute() { |
| MLContext ml = MLContext.getActiveMLContext(); |
| if (ml == null) { |
| throw new MLContextException("No MLContext object exists. Please create one."); |
| } |
| return ml.execute(this); |
| } |
| } |