blob: 017d56c462ca10e7e9b438fe5ea026bc756354aa [file] [log] [blame]
# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------
__all__ = ["SystemDSContext"]
import copy
import os
import time
from glob import glob
from queue import Empty, Queue
from subprocess import PIPE, Popen
from threading import Lock, Thread
from time import sleep
from typing import Dict, Iterable, Sequence, Tuple, Union
import numpy as np
from py4j.java_gateway import GatewayParameters, JavaGateway
from py4j.protocol import Py4JNetworkError
from systemds.utils.consts import VALID_INPUT_TYPES
from systemds.utils.helpers import get_module_dir
from systemds.operator import OperationNode
class SystemDSContext(object):
"""A context with a connection to a java instance with which SystemDS operations are executed.
The java process is started and is running using a random tcp port for instruction parsing."""
java_gateway: JavaGateway
__stdout: Queue
__stderr: Queue
def __init__(self):
"""Starts a new instance of SystemDSContext, in which the connection to a JVM systemds instance is handled
Any new instance of this SystemDS Context, would start a separate new JVM.
Standard out and standard error form the JVM is also handled in this class, filling up Queues,
that can be read from to get the printed statements from the JVM.
"""
root = os.environ.get("SYSTEMDS_ROOT")
if root == None:
# If there is no systemds install default to use the PIP packaged java files.
root = os.path.join(get_module_dir(), "systemds-java")
# nt means its Windows
cp_separator = ";" if os.name == "nt" else ":"
if os.environ.get("SYSTEMDS_ROOT") != None:
lib_cp = os.path.join(root, "target", "lib", "*")
systemds_cp = os.path.join(root, "target", "SystemDS.jar")
classpath = cp_separator.join([lib_cp, systemds_cp])
command = ["java", "-cp", classpath]
files = glob(os.path.join(root, "conf", "log4j*.properties"))
if len(files) > 1:
print(
"WARNING: Multiple logging files found selecting: " + files[0])
if len(files) == 0:
print("WARNING: No log4j file found at: "
+ os.path.join(root, "conf")
+ " therefore using default settings")
else:
command.append("-Dlog4j.configuration=file:" + files[0])
else:
lib_cp = os.path.join(root, "lib", "*")
command = ["java", "-cp", lib_cp]
command.append("org.apache.sysds.api.PythonDMLScript")
# TODO add an argument parser here
# Find a random port, and hope that no other process
# steals it while we wait for the JVM to startup
port = self.__get_open_port()
command.append(str(port))
process = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE)
first_stdout = process.stdout.readline()
if(not b"GatewayServer Started" in first_stdout):
stderr = process.stderr.readline().decode("utf-8")
if(len(stderr) > 1):
raise Exception(
"Exception in startup of GatewayServer: " + stderr)
outputs = []
outputs.append(first_stdout.decode("utf-8"))
max_tries = 10
for i in range(max_tries):
next_line = process.stdout.readline()
if(b"GatewayServer Started" in next_line):
print("WARNING: Stdout corrupted by prints: " + str(outputs))
print("Startup success")
break
else:
outputs.append(next_line)
if (i == max_tries-1):
raise Exception("Error in startup of systemDS gateway process: \n gateway StdOut: " + str(
outputs) + " \n gateway StdErr" + process.stderr.readline().decode("utf-8"))
# Handle Std out from the subprocess.
self.__stdout = Queue()
self.__stderr = Queue()
Thread(target=self.__enqueue_output, args=(
process.stdout, self.__stdout), daemon=True).start()
Thread(target=self.__enqueue_output, args=(
process.stderr, self.__stderr), daemon=True).start()
# Py4j connect to the started process.
gwp = GatewayParameters(port=port, eager_load=True)
self.java_gateway = JavaGateway(
gateway_parameters=gwp, java_process=process)
def get_stdout(self, lines: int = -1):
"""Getter for the stdout of the java subprocess
The output is taken from the stdout queue and returned in a new list.
:param lines: The number of lines to try to read from the stdout queue.
default -1 prints all current lines in the queue.
"""
if lines == -1 or self.__stdout.qsize() < lines:
return [self.__stdout.get() for x in range(self.__stdout.qsize())]
else:
return [self.__stdout.get() for x in range(lines)]
def get_stderr(self, lines: int = -1):
"""Getter for the stderr of the java subprocess
The output is taken from the stderr queue and returned in a new list.
:param lines: The number of lines to try to read from the stderr queue.
default -1 prints all current lines in the queue.
"""
if lines == -1 or self.__stderr.qsize() < lines:
return [self.__stderr.get() for x in range(self.__stderr.qsize())]
else:
return [self.__stderr.get() for x in range(lines)]
def exception_and_close(self, e):
"""
Method for printing exception, printing stdout and error, while also closing the context correctly.
"""
# e = sys.exc_info()[0]
print("Exception Encountered! closing JVM")
print("standard out:")
[print(x) for x in self.get_stdout()]
print("standard error")
[print(x) for x in self.get_stderr()]
print("exception")
print(e)
self.close()
exit()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
# no errors to handle to allow continuation
return None
def close(self):
"""Close the connection to the java process and do necessary cleanup."""
process: Popen = self.java_gateway.java_process
self.java_gateway.shutdown()
# Send SigTerm
os.kill(process.pid, 14)
def __enqueue_output(self, out, queue):
"""Method for handling the output from java.
It is locating the string handeling inside a different thread, since the 'out.readline' is a blocking command.
"""
for line in iter(out.readline, b""):
queue.put(line.decode("utf-8").strip())
def __get_open_port(self):
"""Get a random available port."""
# TODO Verify that it is not taking some critical ports change to select a good port range.
# TODO If it tries to select a port already in use, find another.
# https://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
def full(self, shape: Tuple[int, int], value: Union[float, int]) -> 'OperationNode':
"""Generates a matrix completely filled with a value
:param sds_context: SystemDS context
:param shape: shape (rows and cols) of the matrix TODO tensor
:param value: the value to fill all cells with
:return: the OperationNode representing this operation
"""
unnamed_input_nodes = [value]
named_input_nodes = {'rows': shape[0], 'cols': shape[1]}
return OperationNode(self, 'matrix', unnamed_input_nodes, named_input_nodes)
def seq(self, start: Union[float, int], stop: Union[float, int] = None,
step: Union[float, int] = 1) -> 'OperationNode':
"""Create a single column vector with values from `start` to `stop` and an increment of `step`.
If no stop is defined and only one parameter is given, then start will be 0 and the parameter will be interpreted as
stop.
:param sds_context: SystemDS context
:param start: the starting value
:param stop: the maximum value
:param step: the step size
:return: the OperationNode representing this operation
"""
if stop is None:
stop = start
start = 0
unnamed_input_nodes = [start, stop, step]
return OperationNode(self, 'seq', unnamed_input_nodes)
def rand(self, rows: int, cols: int,
min: Union[float, int] = None, max: Union[float, int] = None, pdf: str = "uniform",
sparsity: Union[float, int] = None, seed: Union[float, int] = None,
lambd: Union[float, int] = 1) -> 'OperationNode':
"""Generates a matrix filled with random values
:param sds_context: SystemDS context
:param rows: number of rows
:param cols: number of cols
:param min: min value for cells
:param max: max value for cells
:param pdf: "uniform"/"normal"/"poison" distribution
:param sparsity: fraction of non-zero cells
:param seed: random seed
:param lambd: lamda value for "poison" distribution
:return:
"""
available_pdfs = ["uniform", "normal", "poisson"]
if rows < 0:
raise ValueError("In rand statement, can only assign rows a long (integer) value >= 0 "
"-- attempted to assign value: {r}".format(r=rows))
if cols < 0:
raise ValueError("In rand statement, can only assign cols a long (integer) value >= 0 "
"-- attempted to assign value: {c}".format(c=cols))
if pdf not in available_pdfs:
raise ValueError("The pdf passed is invalid! given: {g}, expected: {e}".format(
g=pdf, e=available_pdfs))
pdf = '\"' + pdf + '\"'
named_input_nodes = {
'rows': rows, 'cols': cols, 'pdf': pdf, 'lambda': lambd}
if min is not None:
named_input_nodes['min'] = min
if max is not None:
named_input_nodes['max'] = max
if sparsity is not None:
named_input_nodes['sparsity'] = sparsity
if seed is not None:
named_input_nodes['seed'] = seed
return OperationNode(self, 'rand', [], named_input_nodes=named_input_nodes)
def read(self, path: os.PathLike, **kwargs: Dict[str, VALID_INPUT_TYPES]):
return OperationNode(self, 'read', [f'"{path}"'], named_input_nodes=kwargs)