| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import os |
| import tempfile |
| import textwrap |
| from typing import Any, Callable |
| |
| from pyspark import cloudpickle |
| |
| |
class FunctionPickler:
    """
    Helpers for running a function in a separate Python process via pickle files.

    All methods are static. The class provides a way to:

    * pickle a function together with its positional and keyword arguments,
    * create a Python script that loads such a pickle, calls the function, and
      pickles the function's output to another file,
    * load (unpickle) that output file.
    """

    @staticmethod
    def pickle_fn_and_save(
        fn: Callable, file_path: str, save_dir: str, *args: Any, **kwargs: Any
    ) -> str:
        """
        Given a function and args, this function will pickle them to a file.

        The function, positional arguments, and keyword arguments are serialized
        together as a single ``(fn, args, kwargs)`` tuple with cloudpickle.

        Parameters
        ----------
        fn: Callable
            The picklable function that will be pickled to a file.
        file_path: str
            The path where to save the pickled function, args, and kwargs. If it's the
            empty string, the function will decide on a random name.
        save_dir: str
            The directory in which to save the file with the pickled function and arguments.
            Ignored if file_path is specified. If both file_path and save_dir are empty,
            the function will write the file to the current working directory with a random
            name.
        *args: Any
            Positional arguments of fn that will be pickled.
        **kwargs: Any
            Keyword arguments to fn that will be pickled.

        Returns
        -------
        str
            The path to the file where the function and arguments are pickled.

        Raises
        ------
        Exception
            Any error raised by cloudpickle while serializing (for example, for an
            unpicklable argument) is propagated. When the file name was chosen by this
            function, the partially written temporary file is removed first.
        """
        payload = (fn, args, kwargs)

        if file_path != "":
            with open(file_path, "wb") as f:
                cloudpickle.dump(payload, f)
            return f.name

        if save_dir == "":
            save_dir = os.getcwd()

        # delete=False keeps the file after this call returns, which also makes us
        # responsible for removing it if serialization fails part-way through.
        temp_file = tempfile.NamedTemporaryFile(dir=save_dir, delete=False)
        try:
            with temp_file:
                cloudpickle.dump(payload, temp_file)
        except BaseException:
            # The ``with`` above has already closed the file; closing first matters on
            # Windows, where an open file cannot be deleted.
            try:
                os.remove(temp_file.name)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise
        return temp_file.name

    @staticmethod
    def create_fn_run_script(
        pickled_fn_path: str,
        fn_output_path: str,
        script_path: str,
        prefix_code: str = "",
        suffix_code: str = "",
    ) -> str:
        """
        Given a file containing a pickled function and arguments, this function will create a
        Python script that executes the function and pickles the function's output.

        Parameters
        ----------
        pickled_fn_path: str
            This is the path of the file containing the pickled function, args, and kwargs.
        fn_output_path: str
            This is the location where the created script will save the pickled output of
            the function.
        script_path: str
            This is the path which will be used for the created Python script.
        prefix_code: str
            This contains a string that the user can pass in which will be executed before
            the code generated by this class to execute the function and save it. If
            prefix_code is the empty string, nothing will be written before the auto-
            generated code.
        suffix_code: str
            This contains a string of code that the user can pass in which will be executed
            after the code generated by this class finishes executing. If suffix_code is
            the empty string, nothing will be written after the auto-generated code.

        Returns
        -------
        str
            The path to the location of the newly created Python script.
        """
        # Embed each path as a proper Python string literal via repr(). Splicing the raw
        # text between double quotes breaks on Windows paths such as C:\Users\...
        # (``\U`` is a syntax error, ``\t``/``\n`` become control characters) and on
        # paths containing a double quote. os.fspath keeps PathLike arguments working.
        pickled_fn_literal = repr(os.fspath(pickled_fn_path))
        fn_output_literal = repr(os.fspath(fn_output_path))

        code_snippet = textwrap.dedent(
            f"""
            from pyspark import cloudpickle
            import os

            if __name__ == "__main__":
                with open({pickled_fn_literal}, "rb") as f:
                    fn, args, kwargs = cloudpickle.load(f)
                output = fn(*args, **kwargs)
                with open({fn_output_literal}, "wb") as f:
                    cloudpickle.dump(output, f)
            """
        )
        # Python source files are decoded as UTF-8 by default (PEP 3120), so write them
        # that way instead of relying on the platform's locale encoding.
        with open(script_path, "w", encoding="utf-8") as f:
            if prefix_code != "":
                f.write(prefix_code)
            f.write(code_snippet)
            if suffix_code != "":
                f.write(suffix_code)

        return script_path

    @staticmethod
    def get_fn_output(fn_output_path: str) -> Any:
        """
        Given a path to a file with pickled output, this function
        will unpickle the output and return it to the user.

        Parameters
        ----------
        fn_output_path: str
            The path to the file containing the pickled output of a function.

        Returns
        -------
        Any
            The unpickled output stored in fn_output_path.
        """
        # NOTE: unpickling can execute arbitrary code; only load files produced by a
        # trusted process (e.g. a script generated by create_fn_run_script).
        with open(fn_output_path, "rb") as f:
            return cloudpickle.load(f)