blob: 1f9d5de8feeb8e8e39eac8bb08f3544bbc8ce71c [file] [log] [blame]
# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------
from typing import TYPE_CHECKING, Any, Collection, Dict, KeysView, List, Optional, Tuple, Union

from py4j.java_collections import JavaArray
from py4j.java_gateway import JavaGateway, JavaObject

from systemds.script_building.dag import DAGNode, OutputType
from systemds.utils.consts import VALID_INPUT_TYPES

if TYPE_CHECKING:
    # to avoid cyclic dependencies during runtime
    from systemds.context import SystemDSContext
class DMLScript:
    """DMLScript is the class used to describe our intended behavior in DML. This script can then be executed to
    get the results.

    Code is accumulated by walking a DAG of operations (`build_code`), python local inputs are recorded while
    walking, and the script is compiled into a prepared script on the first execution.

    TODO caching
    TODO rerun with different inputs without recompilation
    """
    # context that owns the java gateway used to compile and run the script
    sds_context: 'SystemDSContext'
    # the accumulated DML source, one statement per line
    dml_script: str
    # python local inputs, keyed by the DML variable name they are bound to
    inputs: Dict[str, DAGNode]
    # the compiled script; None until the first execution
    prepared_script: Optional[Any]
    # DML variable names registered as outputs of the prepared script
    out_var_name: List[str]
    # counter used to generate unique DML variable names (V0, V1, ...)
    _variable_counter: int

    def __init__(self, context: 'SystemDSContext') -> None:
        self.sds_context = context
        self.dml_script = ''
        self.inputs = {}
        self.prepared_script = None
        self.out_var_name = []
        self._variable_counter = 0

    def add_code(self, code: str) -> None:
        """Add a dml code line to our script.

        :param code: the dml code line (a newline is appended)
        """
        self.dml_script += code + '\n'

    def add_input_from_python(self, var_name: str, input_var: DAGNode) -> None:
        """Add an input for our prepared script. Should only be executed for data that is python local.

        :param var_name: name of the DML variable the data is bound to
        :param input_var: the DAGNode object which holds the data
        """
        self.inputs[var_name] = input_var

    def execute(self, lineage: bool = False) -> Union[JavaObject, Tuple[JavaObject, Union[str, List[str]]]]:
        """Execute the script, creating the prepared script first if that has not happened yet.

        On the first call a prepared script is created from our DML code and the python local data is passed
        to it; later calls reuse that prepared script.

        :param lineage: if True, enable lineage tracing on the connection and also return the lineage trace(s)
        :return: the resultVariables of our execution; if `lineage` is True, a tuple of the resultVariables and
            either a single lineage trace (exactly one output) or a list of traces (zero or several outputs)
        """
        # fetch the connection unconditionally: it is needed for setLineage even when the script was
        # already prepared by an earlier call
        connection = self.sds_context.java_gateway.entry_point.getConnection()
        self._prepare_script(connection)
        if lineage:
            # NOTE(review): nothing here switches lineage back off after it has been enabled
            connection.setLineage(True)
        ret = self._execute_prepared_script()
        if lineage:
            return ret, self._lineage_traces()
        return ret

    def get_lineage(self) -> Union[str, List[str]]:
        """Execute the script with lineage tracing enabled and return the lineage trace(s) of its outputs.

        :return: a single lineage trace when the script has exactly one output, otherwise a list of traces
        """
        connection = self.sds_context.java_gateway.entry_point.getConnection()
        self._prepare_script(connection)
        connection.setLineage(True)
        self._execute_prepared_script()
        return self._lineage_traces()

    def _prepare_script(self, connection: JavaObject) -> None:
        """Create the prepared script and pass python local inputs to it, unless that already happened.

        :param connection: the connection used to compile the DML code
        """
        if self.prepared_script is not None:
            return
        gateway = self.sds_context.java_gateway
        self.prepared_script = connection.prepareScript(
            self.dml_script,
            _list_to_java_array(gateway, self.inputs.keys()),
            _list_to_java_array(gateway, self.out_var_name))
        for name, input_node in self.inputs.items():
            input_node.pass_python_data_to_prepared_script(
                self.sds_context, name, self.prepared_script)

    def _execute_prepared_script(self) -> JavaObject:
        """Run the prepared script, handing any failure to the context's exception handler.

        :return: the resultVariables of the execution
        """
        try:
            return self.prepared_script.executeScript()
        except Exception as e:
            self.sds_context.exception_and_close(e)
            # re-raise in case the handler returns normally, so callers never proceed without a result
            raise

    def _lineage_traces(self) -> Union[str, List[str]]:
        """Collect lineage traces of all outputs: one string for a single output, otherwise a list."""
        if len(self.out_var_name) == 1:
            return self.prepared_script.getLineageTrace(self.out_var_name[0])
        return [self.prepared_script.getLineageTrace(output) for output in self.out_var_name]

    def build_code(self, dag_root: DAGNode) -> None:
        """Builds code from our DAG.

        :param dag_root: the topmost operation of our DAG, result of operation will be output
        """
        base_out_var = self._dfs_dag_nodes(dag_root)
        if dag_root.output_type == OutputType.NONE:
            return
        # NOTE(review): the write statements appear to mark these variables as script outputs; the './tmp'
        # paths look like placeholders — confirm against the JMLC output handling
        if dag_root.number_of_outputs > 1:
            self.out_var_name = []
            for idx in range(dag_root.number_of_outputs):
                out_var = f'{base_out_var}_{idx}'
                self.add_code(f'write({out_var}, \'./tmp_{idx}\');')
                self.out_var_name.append(out_var)
        else:
            self.out_var_name.append(base_out_var)
            self.add_code(f'write({base_out_var}, \'./tmp\');')

    def _dfs_dag_nodes(self, dag_node: VALID_INPUT_TYPES) -> str:
        """Uses Depth-First-Search to create code from DAG.

        Non-DAGNode values are literals: booleans become 'TRUE'/'FALSE', everything else is rendered via str().

        :param dag_node: current DAG node
        :return: the variable name the current DAG node operation created (or the literal's DML text)
        """
        if not isinstance(dag_node, DAGNode):
            if isinstance(dag_node, bool):
                return 'TRUE' if dag_node else 'FALSE'
            return str(dag_node)
        # visit the inputs first so their code lines precede this node's line;
        # variable names of unnamed parameters
        unnamed_input_vars = [self._dfs_dag_nodes(input_node) for input_node in dag_node.unnamed_input_nodes]
        # variable names of named parameters
        named_input_vars = {name: self._dfs_dag_nodes(input_node)
                            for name, input_node in dag_node.named_input_nodes.items()}
        curr_var_name = self._next_unique_var()
        if dag_node.is_python_local_data:
            self.add_input_from_python(curr_var_name, dag_node)
        code_line = dag_node.code_line(curr_var_name, unnamed_input_vars, named_input_vars)
        self.add_code(code_line)
        return curr_var_name

    def _next_unique_var(self) -> str:
        """Gets the next unique variable name.

        :return: the next variable name (id), e.g. 'V0', 'V1', ...
        """
        var_id = self._variable_counter
        self._variable_counter += 1
        return f'V{var_id}'
# Helper Functions
def _list_to_java_array(gateway: JavaGateway, py_list: Union[Collection[str], KeysView[str]]) -> JavaArray:
    """Copy a python collection of strings into a newly allocated java String array.

    :param gateway: the java gateway used to allocate the array
    :param py_list: python collection of strings; its iteration order is the array order
    :return: java array of the same length holding the same elements
    """
    size = len(py_list)
    string_class = gateway.jvm.java.lang.String
    java_array = gateway.new_array(string_class, size)
    index = 0
    for item in py_list:
        java_array[index] = item
        index += 1
    return java_array