# blob: c130cf1ff6b9db2d38f83b91af3b90ab09853400
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from contextlib import contextmanager
import os
import textwrap
from typing import Any, BinaryIO, Callable, Iterator
import unittest
from parameterized import parameterized
from pyspark import cloudpickle
from pyspark.ml.dl_util import FunctionPickler
class TestFunctionPickler(unittest.TestCase):
    """Tests for ``FunctionPickler`` from ``pyspark.ml.dl_util``.

    Exercises pickling a function together with its arguments, loading the
    pickled payload back, and generating the script that runs a pickled
    function. Every file or directory a test creates is registered with
    ``addCleanup`` so it is removed even when an assertion fails.
    """

    # Function that will be used to test pickling.
    @staticmethod
    def _test_function(x: float, y: float) -> float:
        """Return the sum of the squares of ``x`` and ``y``."""
        return x**2 + y**2

    def _check_if_test_function_pickled(
        self,
        file: BinaryIO,
        desired_function: Callable,
        output_value: Any,
        *arguments,
        **key_word_args,
    ):
        """Assert that ``file`` holds a pickled ``(fn, args, kwargs)`` triple.

        The triple must match ``desired_function``, ``arguments`` and
        ``key_word_args``, and calling the loaded function with the loaded
        arguments must produce ``output_value``.
        """
        fn, args, kwargs = cloudpickle.load(file)
        self.assertEqual(fn, desired_function)
        self.assertEqual(args, arguments)
        self.assertEqual(kwargs, key_word_args)
        fn_output = fn(*args, **kwargs)
        self.assertEqual(fn_output, output_value)

    @parameterized.expand(
        [
            ("See if it pickles correctly with no path specified", "", ""),
            ("See if it pickles correctly with path specified", "silly_bear", ""),
            (
                "See if it pickles correctly with both path and save_dir specified",
                "silly_bear",
                "tmp_dir",
            ),
        ]
    )
    def test_pickle_fn_and_save(self, _: str, file_path_to_save: str, save_dir: str):
        # Pickle the test function with each path/save_dir combination and
        # verify the saved payload round-trips.
        x, y = 1, 3  # args of test_function
        if save_dir != "":
            os.makedirs(save_dir, exist_ok=True)
            # Cleanups run in LIFO order, so the directory is removed only
            # after the pickle file registered below has been deleted.
            self.addCleanup(os.rmdir, save_dir)
        pickled_fn_path = FunctionPickler.pickle_fn_and_save(
            TestFunctionPickler._test_function, file_path_to_save, save_dir, x, y
        )
        # Registered before any assertion so a failure cannot leak the file.
        self.addCleanup(os.remove, pickled_fn_path)
        if file_path_to_save != "":
            self.assertEqual(file_path_to_save, pickled_fn_path)
        with open(pickled_fn_path, "rb") as f:
            self._check_if_test_function_pickled(f, TestFunctionPickler._test_function, 10, x, y)

    def test_getting_output_from_pickle_file(self):
        # Load a saved pickle through FunctionPickler.get_fn_output and check
        # that the function and its arguments come back intact.
        a, b = 2, 0  # arguments for _test_function
        pickle_fn_file = FunctionPickler.pickle_fn_and_save(
            TestFunctionPickler._test_function, "", "", a, b
        )
        # Registered before any assertion so a failure cannot leak the file.
        self.addCleanup(os.remove, pickle_fn_file)
        fn, args, kwargs = FunctionPickler.get_fn_output(pickle_fn_file)
        self.assertEqual(fn, TestFunctionPickler._test_function)
        self.assertEqual(len(args), 2)
        self.assertEqual(len(kwargs), 0)
        self.assertEqual(args[0], a)
        self.assertEqual(args[1], b)
        self.assertEqual(fn(*args, **kwargs), 4)

    @contextmanager
    def create_reference_file(
        self, body: str, prefix: str = "", suffix: str = "", fname: str = "reference.py"
    ) -> Iterator[None]:
        """Write ``prefix + body + suffix`` to ``fname`` for the ``with`` block.

        The file is deleted when the block exits, whether or not it raised.
        """
        try:
            with open(fname, "w") as f:
                # Writing an empty prefix or suffix is a no-op, so a single
                # concatenated write yields the same file contents.
                f.write(prefix + body + suffix)
            yield
        finally:
            try:
                os.remove(fname)
            except FileNotFoundError:
                # The file was never created (e.g. open() failed); let the
                # original error propagate instead of masking it here.
                pass

    def _create_code_snippet_body(self, pickled_fn_path: str, fn_output_save_path: str) -> str:
        """Build the reference script body for the two given paths.

        The script loads the pickled ``(fn, args, kwargs)`` from
        ``pickled_fn_path``, calls the function, and dumps its output to
        ``fn_output_save_path``.
        """
        # NOTE(review): test_create_fn_run_script compares this text byte for
        # byte with the file written by FunctionPickler.create_fn_run_script,
        # so the whitespace and blank-line layout of this template must match
        # that implementation exactly -- confirm against it.
        code_snippet = textwrap.dedent(
            f"""
            from pyspark import cloudpickle
            import os
            if __name__ == "__main__":
                with open("{pickled_fn_path}", "rb") as f:
                    fn, args, kwargs = cloudpickle.load(f)
                output = fn(*args, **kwargs)
                with open("{fn_output_save_path}", "wb") as f:
                    cloudpickle.dump(output, f)
            """
        )
        return code_snippet

    def _are_two_files_identical(self, fpath1: str, fpath2: str) -> bool:
        """Assert that both files have the same bytes and return the result.

        The assertion produces a readable diff on mismatch; the boolean is
        returned so callers can wrap the call in their own assertion.
        """
        with open(fpath1, "rb") as f:
            contents_one = f.read()
        with open(fpath2, "rb") as f:
            contents_two = f.read()
        self.assertEqual(contents_one, contents_two)
        return contents_one == contents_two

    @parameterized.expand(
        [
            ("Check if it creates the correct file with no prefix nor suffix", "", ""),
            (
                "Check if it creates the correct file with only prefix + body",
                "print('hello before')\n",
                "",
            ),
            (
                "Check if it creates the correct file with only suffix + body",
                "",
                "print('goodbye')",
            ),
            (
                "Check if it creates the correct file prefix, body, and suffix",
                "print('hello')\n",
                "print('goodbye')\n",
            ),
        ]
    )
    def test_create_fn_run_script(self, mesg: str, prefix_test: str, suffix_test: str):
        # Generate the run script for a pickled function and compare it with a
        # reference file assembled from the same prefix, body and suffix.
        arg1, arg2 = 3, 4
        pickled_fn_path = FunctionPickler.pickle_fn_and_save(
            TestFunctionPickler._test_function, "", "", arg1, arg2
        )
        # Registered before any assertion so a failure cannot leak the file.
        self.addCleanup(os.remove, pickled_fn_path)
        fn_out_path = "output.pickled"
        reference_path = "ref_result_file.py"
        test_path = "test_result.py"
        body_for_reference = self._create_code_snippet_body(pickled_fn_path, fn_out_path)
        with self.create_reference_file(
            body_for_reference, prefix=prefix_test, suffix=suffix_test, fname=reference_path
        ):
            executable_file_path = FunctionPickler.create_fn_run_script(
                pickled_fn_path,
                fn_out_path,
                test_path,
                prefix_code=prefix_test,
                suffix_code=suffix_test,
            )
            # Remove the generated script even if the comparison below fails.
            self.addCleanup(os.remove, executable_file_path)
            self.assertTrue(self._are_two_files_identical(reference_path, executable_file_path))
if __name__ == "__main__":
    # Re-import the tests under their package-qualified module name before
    # handing control to unittest.
    from pyspark.ml.tests.test_dl_util import *  # noqa: F401

    # Default to unittest's built-in runner; switch to the XML-reporting
    # runner only when the optional xmlrunner package can be imported.
    runner = None
    try:
        import xmlrunner

        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep the default runner.
        pass
    unittest.main(testRunner=runner, verbosity=2)