| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import pickle |
| from abc import ABC, abstractmethod |
| from typing import List, Dict, Any |
| |
| from py4j.java_gateway import JavaObject, get_java_class |
| from pyflink.common import typeinfo, Time, Row, RowKind |
| from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of, Types, \ |
| ExternalTypeInfo, RowTypeInfo, TupleTypeInfo |
| from pyflink.datastream import utils |
| from pyflink.datastream.utils import pickled_bytes_to_python_converter |
| from pyflink.java_gateway import get_gateway |
| from pyflink.table import Table, StreamTableEnvironment, Expression |
| from pyflink.util.java_utils import to_jarray |
| |
| from pyflink.ml.api import Model, Transformer, AlgoOperator, Stage, Estimator |
| from pyflink.ml.linalg import DenseVectorTypeInfo, SparseVectorTypeInfo, DenseMatrixTypeInfo, \ |
| VectorTypeInfo, DenseVector |
| from pyflink.ml.param import Param, WithParams, StringArrayParam, IntArrayParam, VectorParam, \ |
| FloatArrayParam, FloatArrayArrayParam, WindowsParam |
| from pyflink.ml.common.window import GlobalWindows, CountTumblingWindows, \ |
| EventTimeTumblingWindows, ProcessingTimeTumblingWindows, EventTimeSessionWindows, \ |
| ProcessingTimeSessionWindows |
| |
# Keep a reference to pyflink's original ``_from_java_type`` so the wrapper below can delegate
# to it for every type it does not handle itself, even after ``typeinfo._from_java_type`` is
# replaced by the wrapper.
_from_java_type_alias = _from_java_type
| |
| |
def _from_java_type_wrapper(j_type_info: JavaObject) -> TypeInformation:
    """
    Converts a Java ``TypeInformation`` into its Python counterpart, with support for the
    Flink ML linear-algebra classes.

    When ``j_type_info`` is a ``GenericTypeInfo`` whose type class is exactly one of the Java
    classes ``DenseVector``, ``SparseVector``, ``DenseMatrix`` or ``Vector`` from
    ``org.apache.flink.ml.linalg``, the matching Flink ML Python type info is returned.
    Anything else is delegated to pyflink's original ``_from_java_type``.

    :param j_type_info: The Java TypeInformation object to convert.
    :return: The corresponding Python TypeInformation.
    """
    gateway = get_gateway()
    generic_type_info = gateway.jvm.org.apache.flink.api.java.typeutils.GenericTypeInfo
    if not _is_instance_of(j_type_info, generic_type_info):
        return _from_java_type_alias(j_type_info)

    type_class = j_type_info.getTypeClass()
    linalg_package = gateway.jvm.org.apache.flink.ml.linalg
    # Java classes are looked up lazily (by name) so that we stop at the first match,
    # exactly like an if/elif chain would.
    for java_class_name, python_type_info in (
            ('DenseVector', DenseVectorTypeInfo),
            ('SparseVector', SparseVectorTypeInfo),
            ('DenseMatrix', DenseMatrixTypeInfo),
            ('Vector', VectorTypeInfo)):
        if type_class == get_java_class(getattr(linalg_package, java_class_name)):
            return python_type_info()
    return _from_java_type_alias(j_type_info)
| |
| |
# Monkey-patch pyflink so that Java type infos of Flink ML linear-algebra classes are turned into
# Flink ML Python type infos. NOTE(review): only code that looks up
# ``typeinfo._from_java_type`` at call time sees the patch; modules that bound the original
# function by name at import time keep using it.
typeinfo._from_java_type = _from_java_type_wrapper
| |
| |
# TODO: Remove this function (and the monkey-patch that installs it) once Flink ML depends on
# a Flink version with FLINK-30168 and FLINK-29477 fixed.
def convert_to_python_obj_wrapper(data, type_info):
    """
    Converts a value coming from the Java side into a Python object according to ``type_info``.

    - ``Types.PICKLED_BYTE_ARRAY()``: ``data`` is unpickled directly.
    - ``ExternalTypeInfo``: unwrapped and converted with the wrapped type info.
    - Otherwise ``data`` is pickled on the Java side by ``PythonBridgeUtils`` and decoded with
      ``pickled_bytes_to_python_converter``. For ``RowTypeInfo`` / ``TupleTypeInfo`` the Java
      call yields a sequence whose element 0 is skipped for fields (for rows it encodes the
      ``RowKind`` as a little-endian integer) and whose remaining elements hold one pickled
      payload per field; an empty payload is decoded as ``None``.

    :param data: The value to convert.
    :param type_info: The pyflink TypeInformation describing ``data``.
    :return: The Python object; a ``Row`` for row types, a ``tuple`` for tuple types.
    """
    if type_info == Types.PICKLED_BYTE_ARRAY():
        return pickle.loads(data)
    if isinstance(type_info, ExternalTypeInfo):
        return convert_to_python_obj_wrapper(data, type_info._type_info)

    gateway = get_gateway()
    pickle_bytes = gateway.jvm.org.apache.flink.ml.python.PythonBridgeUtils. \
        getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
    if not isinstance(type_info, (RowTypeInfo, TupleTypeInfo)):
        return pickled_bytes_to_python_converter(pickle_bytes, type_info)

    # Field payloads start at index 1; an empty payload marks a null field.
    fields = []
    for field_bytes, field_type in zip(list(pickle_bytes[1:]), type_info.get_field_types()):
        if len(field_bytes) == 0:
            fields.append(None)
        else:
            fields.append(pickled_bytes_to_python_converter(field_bytes, field_type))

    if isinstance(type_info, RowTypeInfo):
        row_kind = RowKind(int.from_bytes(pickle_bytes[0], 'little'))
        return Row.of_kind(row_kind, *fields)
    return tuple(fields)
| |
| |
# Monkey-patch pyflink.datastream.utils so its Java-to-Python value conversion goes through the
# Flink ML wrapper defined above (which uses Flink ML's PythonBridgeUtils).
utils.convert_to_python_obj = convert_to_python_obj_wrapper
| |
| |
class JavaWrapper(ABC):
    """
    Base class for Python objects that delegate to an underlying Java object.

    The wrapped object is stored, unchanged, in ``self._java_obj``.
    """

    def __init__(self, java_obj):
        # Keep a handle on the Java peer; subclasses call into it through this attribute.
        self._java_obj = java_obj
| |
| |
class JavaWithParams(WithParams, JavaWrapper):
    """
    Wrapper class for a Java WithParams.

    Parameter values are read and written by reflectively invoking the wrapped Java object's
    ``get<Name>`` / ``set<Name>`` methods, where ``<Name>`` is derived from the camel-case form
    of the Python parameter's snake-case name. Values are translated between Python and Java
    by the converter registered in ``_map_java_param_converter`` for the exact type of the
    param, falling back to ``default_converter``.
    """

    def __init__(self, java_params):
        super(JavaWithParams, self).__init__(java_params)

    @staticmethod
    def _converter_for(param: Param):
        """Returns the Python/Java value converter registered for ``type(param)``."""
        # Exact-type lookup: Param subclasses that are not registered use the default.
        return _map_java_param_converter.get(type(param), default_converter)

    @staticmethod
    def _java_accessor_name(prefix: str, param: Param) -> str:
        """Builds the Java accessor name, e.g. ('set', 'max_iter') -> 'setMaxIter'."""
        java_param_name = snake_to_camel(param.name)
        return ''.join([prefix, java_param_name[0].upper(), java_param_name[1:]])

    def set(self, param: Param, value) -> WithParams:
        """
        Sets ``param`` to ``value`` on the wrapped Java object by calling its setter.

        :return: This instance, to allow call chaining.
        """
        converter = self._converter_for(param)
        set_method_name = self._java_accessor_name('set', param)

        gateway = get_gateway()
        gateway.jvm.org.apache.flink.iteration.utils.ReflectionUtils.callMethod(
            self._java_obj,
            self._java_obj.getClass(),
            set_method_name,
            to_jarray(gateway.jvm.Object, [converter.to_java(value)])
        )
        return self

    def get(self, param: Param):
        """
        Reads ``param`` from the wrapped Java object by calling its getter, and converts the
        result back to a Python value.
        """
        converter = self._converter_for(param)
        get_method_name = self._java_accessor_name('get', param)

        gateway = get_gateway()
        result = gateway.jvm.org.apache.flink.iteration.utils.ReflectionUtils.callMethod(
            self._java_obj,
            self._java_obj.getClass(),
            get_method_name
        )
        return converter.to_python(result)

    def get_param_map(self) -> Dict[Param, Any]:
        # NOTE(review): this returns whatever the Java ``getParamMap()`` returns (presumably a
        # py4j proxy of a Java map keyed by Java Param objects), not a Python dict of Python
        # Params as the annotation suggests — confirm what callers expect.
        return self._java_obj.getParamMap()
| |
| |
class JavaStage(Stage, JavaWithParams, ABC):
    """
    Wrapper class for a Java Stage.

    Concrete subclasses name the wrapped Java class through ``_java_stage_path``; that class's
    static ``load`` method is what ``load`` below uses to restore a saved stage.
    """

    def __init__(self, java_stage):
        super().__init__(java_stage)

    def save(self, path: str) -> None:
        """Saves the stage by delegating to the wrapped Java stage's ``save``."""
        self._java_obj.save(path)

    @classmethod
    def load(cls, t_env: StreamTableEnvironment, path: str):
        """
        Loads a stage saved at ``path`` with the Java class's static ``load`` and wraps the
        result in a new instance of ``cls``.
        """
        java_stage_class = _to_java_reference(cls._java_stage_path())
        java_stage = java_stage_class.load(t_env._j_tenv, path)
        return cls(java_stage)

    @classmethod
    @abstractmethod
    def _java_stage_path(cls) -> str:
        """Returns the dot-separated, fully-qualified name of the wrapped Java class."""
| |
| |
class JavaAlgoOperator(AlgoOperator, JavaStage, ABC):
    """
    Wrapper class for a Java AlgoOperator.

    If no Java object is supplied, one is created with the no-argument constructor of the
    class named by ``_java_stage_path()``.
    """

    def __init__(self, java_algo_operator):
        if java_algo_operator is None:
            java_algo_operator = _to_java_reference(self._java_stage_path())()
        super().__init__(java_algo_operator)

    def transform(self, *inputs: Table) -> List[Table]:
        """
        Runs the Java operator on ``inputs`` and wraps each Java output table in a pyflink
        ``Table`` bound to the table environment of the first input.
        """
        j_outputs = self._java_obj.transform(_to_java_tables(*inputs))
        return [Table(j_output, inputs[0]._t_env) for j_output in j_outputs]
| |
| |
class JavaTransformer(Transformer, JavaAlgoOperator, ABC):
    """
    Wrapper class for a Java Transformer.

    Adds nothing beyond forwarding the wrapped Java object to its base classes.
    """

    def __init__(self, java_transformer):
        super().__init__(java_transformer)
| |
| |
class JavaModel(Model, JavaTransformer, ABC):
    """
    Wrapper class for a Java Model.

    The table environment of the model-data tables is remembered so that tables returned by
    ``get_model_data`` can be bound to it.
    """

    def __init__(self, java_model):
        super().__init__(java_model)
        # Unknown until set_model_data is called.
        self._t_env = None

    def set_model_data(self, *inputs: Table) -> Model:
        """Passes ``inputs`` to the Java model as its model data and returns this instance."""
        self._t_env = inputs[0]._t_env
        j_tables = _to_java_tables(*inputs)
        self._java_obj.setModelData(j_tables)
        return self

    def get_model_data(self) -> List[Table]:
        """
        Returns the Java model's model data wrapped as pyflink Tables.

        NOTE(review): before set_model_data has been called, the returned Tables carry a
        ``None`` table environment.
        """
        j_tables = self._java_obj.getModelData()
        return [Table(j_table, self._t_env) for j_table in j_tables]
| |
| |
class JavaEstimator(Estimator, JavaStage, ABC):
    """
    Wrapper class for a Java Estimator.

    A fresh Java estimator is created from ``_java_stage_path()`` on construction.
    """

    def __init__(self):
        java_estimator = _new_java_obj(self._java_stage_path())
        super().__init__(java_estimator)

    def fit(self, *inputs: Table) -> Model:
        """Fits the Java estimator on ``inputs`` and wraps the fitted Java model."""
        j_model = self._java_obj.fit(_to_java_tables(*inputs))
        return self._create_model(j_model)

    @classmethod
    def _create_model(cls, java_model) -> Model:
        """
        Creates a model from the input Java model reference.

        The base implementation returns None; presumably overridden by concrete estimators.
        """
        return None

    @classmethod
    def load(cls, t_env: StreamTableEnvironment, path: str):
        """
        Instantiates a new stage instance based on the data read from the given path.

        The Java estimator is loaded first; a new ``cls()`` is then created and its wrapped
        Java object replaced by the loaded one.
        """
        loaded_java_estimator = _to_java_reference(cls._java_stage_path()).load(
            t_env._j_tenv, path)
        estimator = cls()
        estimator._java_obj = loaded_java_estimator
        return estimator
| |
| |
class JavaParamConverter(ABC):
    """
    Translates parameter values between their Python and Java representations.
    """

    @abstractmethod
    def to_java(self, value):
        """Converts the Python ``value`` into the object passed to the Java setter."""

    @abstractmethod
    def to_python(self, value):
        """Converts the value returned by the Java getter into a Python object."""
| |
| |
class DefaultJavaParamConverter(JavaParamConverter):
    """
    Identity converter: values are handed across the Python/Java boundary unchanged.
    """

    def to_java(self, value):
        converted = value
        return converted

    def to_python(self, value):
        converted = value
        return converted
| |
| |
class IntArrayJavaPramConverter(JavaParamConverter):
    """
    Converts between a Python sequence of ints and a Java ``Integer[]``.
    """

    def to_java(self, value):
        integer_class = get_gateway().jvm.java.lang.Integer
        return to_jarray(integer_class, value)

    def to_python(self, value):
        # Index-based access so this works for py4j Java arrays as well as Python sequences.
        return tuple(map(value.__getitem__, range(len(value))))
| |
| |
class FloatArrayJavaPramConverter(JavaParamConverter):
    """
    Converts between a Python sequence of floats and a Java ``Double[]``.
    """

    def to_java(self, value):
        double_class = get_gateway().jvm.java.lang.Double
        return to_jarray(double_class, value)

    def to_python(self, value):
        # Index-based access so this works for py4j Java arrays as well as Python sequences.
        return tuple(map(value.__getitem__, range(len(value))))
| |
| |
class VectorJavaParamConverter(JavaParamConverter):
    """
    Converts between a Python ``DenseVector`` and a Java ``org.apache.flink.ml.linalg.DenseVector``.
    """

    def to_java(self, value):
        jvm = get_gateway().jvm
        j_values = to_jarray(jvm.double, value.to_array())
        return jvm.org.apache.flink.ml.linalg.DenseVector(j_values)

    def to_python(self, value):
        # Read element-wise through the Java vector's get(i) / size() accessors.
        return DenseVector(tuple(map(value.get, range(value.size()))))
| |
| |
class WindowsJavaParamConverter(JavaParamConverter):
    """
    Converts between the Python window descriptors in ``pyflink.ml.common.window`` and their
    Java counterparts in ``org.apache.flink.ml.common.window``.

    Window sizes and session gaps are transferred as ``Time`` values in milliseconds.
    """

    @staticmethod
    def _to_java_time(time: Time):
        return get_gateway().jvm.org.apache.flink.api.common.time.Time.milliseconds(
            time.to_milliseconds())

    @staticmethod
    def _to_python_time(time) -> Time:
        return Time.milliseconds(time.toMilliseconds())

    def to_java(self, value):
        """
        :raises TypeError: If ``value`` is not one of the supported Python window types.
        """
        java_window_package = get_gateway().jvm.org.apache.flink.ml.common.window
        if isinstance(value, GlobalWindows):
            return java_window_package.GlobalWindows.getInstance()
        elif isinstance(value, CountTumblingWindows):
            return java_window_package.CountTumblingWindows.of(value.size)
        elif isinstance(value, EventTimeTumblingWindows):
            return java_window_package.EventTimeTumblingWindows.of(
                WindowsJavaParamConverter._to_java_time(value.size))
        elif isinstance(value, ProcessingTimeTumblingWindows):
            return java_window_package.ProcessingTimeTumblingWindows.of(
                WindowsJavaParamConverter._to_java_time(value.size))
        elif isinstance(value, EventTimeSessionWindows):
            return java_window_package.EventTimeSessionWindows.withGap(
                WindowsJavaParamConverter._to_java_time(value.gap))
        elif isinstance(value, ProcessingTimeSessionWindows):
            return java_window_package.ProcessingTimeSessionWindows.withGap(
                WindowsJavaParamConverter._to_java_time(value.gap))
        else:
            raise TypeError(f"Python object '{value}' cannot be converted to Java object")

    def to_python(self, value):
        """
        :raises TypeError: If ``value``'s class is not one of the supported Java window types.
        """
        # Fetch the Java class name once instead of one JVM round trip per comparison.
        class_name = value.getClass().getName()
        if class_name == "org.apache.flink.ml.common.window.GlobalWindows":
            return GlobalWindows()
        elif class_name == "org.apache.flink.ml.common.window.CountTumblingWindows":
            return CountTumblingWindows.of(value.getSize())
        elif class_name == "org.apache.flink.ml.common.window.EventTimeTumblingWindows":
            return EventTimeTumblingWindows.of(
                WindowsJavaParamConverter._to_python_time(value.getSize()))
        elif class_name == "org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows":
            return ProcessingTimeTumblingWindows.of(
                WindowsJavaParamConverter._to_python_time(value.getSize()))
        elif class_name == "org.apache.flink.ml.common.window.EventTimeSessionWindows":
            return EventTimeSessionWindows.with_gap(
                WindowsJavaParamConverter._to_python_time(value.getGap()))
        elif class_name == "org.apache.flink.ml.common.window.ProcessingTimeSessionWindows":
            return ProcessingTimeSessionWindows.with_gap(
                WindowsJavaParamConverter._to_python_time(value.getGap()))
        else:
            raise TypeError(f"Java object '{value}' cannot be converted to Python object")
| |
| |
class StringArrayJavaParamConverter(JavaParamConverter):
    """
    Converts between a Python sequence of strings and a Java ``String[]``.
    """

    def to_java(self, value):
        string_class = get_gateway().jvm.java.lang.String
        return to_jarray(string_class, value)

    def to_python(self, value):
        # Index-based access so this works for py4j Java arrays as well as Python sequences.
        return tuple(map(value.__getitem__, range(len(value))))
| |
| |
class FloatArrayArrayJavaPramConverter(JavaParamConverter):
    """
    Converts between a Python 2-D sequence of floats and a Java ``Double[][]``.

    The column count is taken from the first row, so input is assumed to be rectangular.
    An empty outer sequence converts to an empty array / empty tuple instead of failing.
    """

    def to_java(self, value):
        gateway = get_gateway()
        num_rows = len(value)
        # Guard against value[0] on an empty outer sequence.
        num_cols = len(value[0]) if num_rows else 0
        j_arr = gateway.new_array(gateway.jvm.java.lang.Double, num_rows, num_cols)
        for i, row in enumerate(value):
            j_row = j_arr[i]
            for j in range(num_cols):
                j_row[j] = row[j]
        return j_arr

    def to_python(self, value):
        num_rows = len(value)
        num_cols = len(value[0]) if num_rows else 0
        rows = []
        for i in range(num_rows):
            # Fetch each row once; indexing a Java array is a round trip to the JVM.
            row = value[i]
            rows.append(tuple(row[j] for j in range(num_cols)))
        return tuple(rows)
| |
| |
# Identity converter used for any Param type without a dedicated entry in the map below.
default_converter = DefaultJavaParamConverter()

# Maps a Param class to the converter used by JavaWithParams.get/set. Lookup is by exact type
# (``type(param)``), so Param subclasses that are not listed fall back to ``default_converter``.
_map_java_param_converter = {
    IntArrayParam: IntArrayJavaPramConverter(),
    FloatArrayParam: FloatArrayJavaPramConverter(),
    FloatArrayArrayParam: FloatArrayArrayJavaPramConverter(),
    StringArrayParam: StringArrayJavaParamConverter(),
    VectorParam: VectorJavaParamConverter(),
    WindowsParam: WindowsJavaParamConverter(),
    # Same as the fallback; listed explicitly for plain Param instances.
    Param: default_converter
}
| |
| |
def snake_to_camel(method_name):
    """
    Converts a snake_case name into lowerCamelCase, e.g. ``'max_iter'`` -> ``'maxIter'``.

    Each underscore-separated segment is capitalized (``str.capitalize``, which also
    lower-cases the rest of the segment). Empty segments, such as those produced by doubled,
    leading or trailing underscores, become a literal ``'_'``. The first character of the
    result is then lower-cased.
    """
    pieces = []
    for segment in method_name.split('_'):
        pieces.append(segment.capitalize() if segment else '_')
    camel = ''.join(pieces)
    # ``camel`` is never empty: an empty input yields a single '_' segment.
    return camel[:1].lower() + camel[1:]
| |
| |
def _to_java_reference(java_class: str):
    """
    Resolves a dot-separated Java name (package, class, or static member) starting from the
    py4j JVM view, e.g. ``'org.apache.flink.ml.linalg.DenseVector'``.
    """
    remaining = java_class.split(".")
    remaining.reverse()
    node = get_gateway().jvm
    while remaining:
        node = getattr(node, remaining.pop())
    return node
| |
| |
def _new_java_obj(java_class: str, *java_args):
    """
    Returns a new Java object.

    :param java_class: Dot-separated, fully-qualified name of the Java class to instantiate.
    :param java_args: Arguments forwarded to the Java constructor.
    """
    return _to_java_reference(java_class)(*java_args)
| |
| |
def _to_java_tables(*inputs: Table):
    """
    Converts Python Tables to Java tables.

    The Java peer (``_j_table``) of each input is packed, in order, into a Java
    ``org.apache.flink.table.api.Table[]``.
    """
    table_class = get_gateway().jvm.org.apache.flink.table.api.Table
    j_tables = [table._j_table for table in inputs]
    return to_jarray(table_class, j_tables)
| |
| |
def call_java_table_function(java_table_function_name: str, *args):
    """
    Calls a Java function that produces a Table API expression and wraps the result.

    :param java_table_function_name: Dot-separated, fully-qualified name of the Java function
        (presumably a static method, ``<package>.<Class>.<method>``), resolved from the JVM root.
    :param args: pyflink ``Expression`` objects; their Java peers are passed to the function as
        a single ``Object[]`` argument.
    :return: A pyflink ``Expression`` wrapping the Java return value.
    """
    # Reuse the shared dotted-name resolver instead of re-implementing the walk here.
    java_function = _to_java_reference(java_table_function_name)
    j_args = to_jarray(
        get_gateway().jvm.java.lang.Object,
        [expression._j_expr for expression in args])
    return Expression(java_function(j_args))