# blob: 9ba3790537c192241f7e8fad5cdc76405d7c4bd6 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# To run:
# - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-class-path SystemML.jar test_mlcontext.py`
# - Python 3: `PYSPARK_PYTHON=python3 spark-submit --master local[*] --driver-class-path SystemML.jar test_mlcontext.py`
# Make the `systemml` package importable
# Standard-library modules used to adjust the import path below.
import os
import sys
# Resolve this file's real location, take its directory, and join "../" onto
# it textually (the result is not normalized), i.e. the parent directory.
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
# Put that directory first on sys.path. This has to happen before the
# `systemml` import further down so the package is searched for there first.
sys.path.insert(0, path)
import unittest
import numpy as np
from pyspark.context import SparkContext
from systemml import MLContext, dml, pydml
# One SparkContext (constructed with no arguments) and one MLContext wrapping
# it are created at import time; the tests below all use these two objects.
sc = SparkContext()
ml = MLContext(sc)
class TestAPI(unittest.TestCase):
    """Checks of the MLContext scripting API using the module-level ``ml``.

    Every test builds a ``dml`` or ``pydml`` script object, declares its
    inputs and outputs, runs it through ``ml.execute`` and compares what
    ``.get(...)`` returns with literal expected values. Matrix inputs are
    supplied as RDDs of comma-separated row strings made with
    ``sc.parallelize``.

    Per-test notes are ``#`` comments rather than docstrings so that
    unittest's verbose report keeps showing the plain method names.
    """

    # A single string output is fetched back as a Python string.
    def test_output_string(self):
        greeting = dml("x1 = 'Hello World'").output("x1")
        outcome = ml.execute(greeting)
        self.assertEqual(outcome.get("x1"), "Hello World")

    # Requesting several scalar outputs at once yields them as a list, in
    # the order the names were passed to ``get``.
    def test_output_list(self):
        source = (
            "\n"
            "x1 = 0.2\n"
            "x2 = x1 + 1\n"
            "x3 = x1 + 2\n"
        )
        prepared = dml(source).output("x1", "x2", "x3")
        # The script is executed afresh for each of the two checks.
        first_pair = ml.execute(prepared).get("x1", "x2")
        self.assertEqual(first_pair, [0.2, 1.2])
        second_pair = ml.execute(prepared).get("x1", "x3")
        self.assertEqual(second_pair, [0.2, 2.2])

    # A scalar and a matrix output can be fetched together; the matrix
    # result's repr is the string "Matrix".
    def test_output_matrix(self):
        source = (
            "\n"
            "s1 = sum(m1)\n"
            "m2 = m1 * 2\n"
        )
        rows = sc.parallelize(["1.0,2.0", "3.0,4.0"])
        prepared = dml(source).input(m1=rows).output("s1", "m2")
        total, doubled = ml.execute(prepared).get("s1", "m2")
        observed = (total, repr(doubled))
        self.assertEqual(observed, (10.0, "Matrix"))

    # Converting a matrix output with ``toDF`` gives a DataFrame whose repr
    # lists an ``__INDEX`` column followed by ``C1`` and ``C2``.
    def test_matrix_toDF(self):
        source = (
            "\n"
            "s1 = sum(m1)\n"
            "m2 = m1 * 2\n"
        )
        rows = sc.parallelize(["1.0,2.0", "3.0,4.0"])
        prepared = dml(source).input(m1=rows).output("m2")
        doubled = ml.execute(prepared).get("m2")
        frame_repr = repr(doubled.toDF())
        self.assertEqual(
            frame_repr,
            "DataFrame[__INDEX: double, C1: double, C2: double]",
        )

    # Converting a matrix output with ``toNumPy`` gives an array that is
    # element-wise equal to the doubled input.
    def test_matrix_toNumPy(self):
        source = (
            "\n"
            "m2 = m1 * 2\n"
        )
        rows = sc.parallelize(["1.0,2.0", "3.0,4.0"])
        prepared = dml(source).input(m1=rows).output("m2")
        doubled = ml.execute(prepared).get("m2")
        actual = doubled.toNumPy()
        expected = np.array([[2.0, 4.0], [6.0, 8.0]])
        elementwise = actual == expected
        self.assertTrue(elementwise.all())

    # An input can be bound positionally as a (name, value) pair.
    def test_input_single(self):
        source = (
            "\n"
            "x2 = x1 + 1\n"
            "x3 = x1 + 2\n"
        )
        prepared = dml(source).input("x1", 5).output("x2", "x3")
        fetched = ml.execute(prepared).get("x2", "x3")
        self.assertEqual(fetched, [6, 7])

    # Several inputs can be bound in one call through keyword arguments.
    def test_input(self):
        source = (
            "\n"
            "x3 = x1 + x2\n"
        )
        prepared = dml(source).input(x1=5, x2=3).output("x3")
        fetched = ml.execute(prepared).get("x3")
        self.assertEqual(fetched, 8)

    # Two RDD inputs bound through separate ``input`` calls, with numeric
    # and string outputs fetched together.
    def test_rdd(self):
        source = (
            "\n"
            "s1 = sum(m1)\n"
            "s2 = sum(m2)\n"
            "s3 = 'whatever'\n"
        )
        first_rows = sc.parallelize(["1.0,2.0", "3.0,4.0"])
        second_rows = sc.parallelize(["5.0,6.0", "7.0,8.0"])
        prepared = (
            dml(source)
            .input(m1=first_rows)
            .input(m2=second_rows)
            .output("s1", "s2", "s3")
        )
        fetched = ml.execute(prepared).get("s1", "s2", "s3")
        self.assertEqual(fetched, [10.0, 26.0, "whatever"])

    # A script written in the PyDML syntax runs through the same
    # ``ml.execute`` entry point; the matrix is compared via ``toString``.
    def test_pydml(self):
        source = "A = full('1 2 3 4 5 6 7 8 9', rows=3, cols=3)\nx = toString(A)"
        prepared = pydml(source).output("x")
        rendered = ml.execute(prepared).get("x")
        self.assertEqual(
            rendered,
            '1.000 2.000 3.000\n4.000 5.000 6.000\n7.000 8.000 9.000\n',
        )
# Run the test suite only when this file is executed as a script (for
# example through spark-submit), not when it is imported.
if __name__ == "__main__":
    unittest.main()