blob: 3359ccc334ba4cf16f02759019d152e429f19e2f [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from py4j.java_gateway import JavaObject, get_java_class
from pyflink.common import typeinfo
from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of
from pyflink.java_gateway import get_gateway
from pyflink.table import Table, StreamTableEnvironment
from pyflink.util.java_utils import to_jarray
from pyflink.ml.core.api import Model, Transformer, AlgoOperator, Stage, Estimator
from pyflink.ml.core.linalg import DenseVectorTypeInfo, SparseVectorTypeInfo, DenseMatrixTypeInfo, \
VectorTypeInfo, DenseVector
from pyflink.ml.core.param import Param, WithParams, StringArrayParam, IntArrayParam, VectorParam, \
FloatArrayParam, FloatArrayArrayParam
# Keep a handle on the converter imported from pyflink.common.typeinfo so that the
# wrapper below can still delegate to it after the module attribute is replaced.
_from_java_type_alias = _from_java_type


def _from_java_type_wrapper(j_type_info: JavaObject) -> TypeInformation:
    """
    Converts a Java type information object to a Python TypeInformation, with
    additional handling for the Flink ML linear algebra classes.

    When the input is a Java GenericTypeInfo whose type class is the Flink ML
    DenseVector, SparseVector, DenseMatrix or Vector, the matching Python type info
    is returned. Every other input is passed to the original converter.

    :param j_type_info: The Java type information object to convert.
    :return: The corresponding Python TypeInformation.
    """
    jvm = get_gateway().jvm
    if not _is_instance_of(j_type_info, jvm.org.apache.flink.api.java.typeutils.GenericTypeInfo):
        return _from_java_type_alias(j_type_info)

    j_type_class = j_type_info.getTypeClass()
    linalg_package = jvm.org.apache.flink.ml.linalg
    # Checked in order; the Java class is only resolved when its entry is reached.
    known_types = (
        ("DenseVector", DenseVectorTypeInfo),
        ("SparseVector", SparseVectorTypeInfo),
        ("DenseMatrix", DenseMatrixTypeInfo),
        ("Vector", VectorTypeInfo),
    )
    for java_name, python_type_info in known_types:
        if j_type_class == get_java_class(getattr(linalg_package, java_name)):
            return python_type_info()
    return _from_java_type_alias(j_type_info)


# Replace the converter on the typeinfo module with the wrapper defined above.
typeinfo._from_java_type = _from_java_type_wrapper
class JavaWrapper(ABC):
    """
    Base class for Python objects that wrap a Java object.

    The wrapped object is kept in the ``_java_obj`` attribute.
    """

    def __init__(self, java_obj) -> None:
        """
        :param java_obj: The Java object to wrap; it is stored as given.
        """
        self._java_obj = java_obj
class JavaWithParams(WithParams, JavaWrapper):
    """
    A WithParams implementation that forwards parameter access to the setter and
    getter methods of the wrapped Java object.
    """

    def __init__(self, java_params):
        super(JavaWithParams, self).__init__(java_params)

    def set(self, param: Param, value) -> WithParams:
        """
        Sets a parameter on the wrapped Java object.

        The Java setter name is derived from the camel-cased parameter name, e.g. a
        parameter whose name converts to ``fooBar`` is set through ``setFooBar``.
        The value is translated with the converter registered for the exact type of
        ``param``, or with the default converter when none is registered.

        :param param: The parameter to set.
        :param value: The Python value of the parameter.
        :return: This instance.
        """
        converter = _map_java_param_converter.get(type(param), default_converter)
        camel_name = snake_to_camel(param.name)
        java_setter = getattr(self._java_obj, 'set' + camel_name[0].upper() + camel_name[1:])
        java_setter(converter.to_java(value))
        return self

    def get(self, param: Param):
        """
        Reads a parameter from the wrapped Java object.

        The Java getter name is derived the same way as the setter name in ``set``,
        and the returned Java value is translated with the converter selected for the
        exact type of ``param``.

        :param param: The parameter to read.
        :return: The value of the parameter after conversion to Python.
        """
        converter = _map_java_param_converter.get(type(param), default_converter)
        camel_name = snake_to_camel(param.name)
        java_getter = getattr(self._java_obj, 'get' + camel_name[0].upper() + camel_name[1:])
        java_value = java_getter()
        return converter.to_python(java_value)

    def get_param_map(self) -> Dict[Param, Any]:
        """
        Returns the result of ``getParamMap()`` on the wrapped Java object, unconverted.
        """
        return self._java_obj.getParamMap()
class JavaStage(Stage, JavaWithParams, ABC):
    """
    A Stage backed by a Java Stage object.
    """

    def __init__(self, java_stage):
        super(JavaStage, self).__init__(java_stage)

    def save(self, path: str) -> None:
        """
        Saves the stage by calling ``save`` on the wrapped Java object.

        :param path: The path that is passed to the Java ``save`` method.
        """
        self._java_obj.save(path)

    @classmethod
    def load(cls, t_env: StreamTableEnvironment, path: str):
        """
        Loads a stage through the static ``load`` method of the Java class named by
        ``_java_stage_path`` and wraps the result in a new instance of ``cls``.

        :param t_env: The table environment; its Java counterpart is handed to Java.
        :param path: The path to load from.
        :return: A new instance of ``cls`` wrapping the loaded Java object.
        """
        j_stage_class = _to_java_reference(cls._java_stage_path())
        return cls(j_stage_class.load(t_env._j_tenv, path))

    @classmethod
    @abstractmethod
    def _java_stage_path(cls) -> str:
        """
        Returns the dot-separated name of the Java class wrapped by this stage.
        """
        pass
class JavaAlgoOperator(AlgoOperator, JavaStage, ABC):
    """
    An AlgoOperator backed by a Java AlgoOperator object.
    """

    def __init__(self, java_algo_operator):
        """
        :param java_algo_operator: The Java object to wrap. When None, a new Java
            object is created from the class named by ``_java_stage_path`` using its
            no-argument constructor.
        """
        if java_algo_operator is None:
            java_algo_operator = _to_java_reference(self._java_stage_path())()
        super(JavaAlgoOperator, self).__init__(java_algo_operator)

    def transform(self, *inputs: Table) -> List[Table]:
        """
        Applies the wrapped Java operator to the given tables.

        :param inputs: The input tables.
        :return: The result tables, each bound to the table environment of the first
            input.
        """
        j_results = self._java_obj.transform(_to_java_tables(*inputs))
        return [Table(j_table, inputs[0]._t_env) for j_table in j_results]
class JavaTransformer(Transformer, JavaAlgoOperator, ABC):
    """
    A Transformer backed by a Java Transformer object.
    """

    def __init__(self, java_transformer):
        """
        :param java_transformer: The Java object to wrap; forwarded unchanged to the
            parent constructor.
        """
        super(JavaTransformer, self).__init__(java_transformer)
class JavaModel(Model, JavaTransformer, ABC):
    """
    A Model backed by a Java Model object.
    """

    def __init__(self, java_model):
        super(JavaModel, self).__init__(java_model)
        # Table environment used to wrap model data tables; set by set_model_data.
        self._t_env = None

    def set_model_data(self, *inputs: Table) -> Model:
        """
        Passes the given model data tables to the wrapped Java model.

        The table environment of the first input is remembered so that
        ``get_model_data`` can wrap the Java tables it receives.

        :param inputs: The model data tables.
        :return: This instance.
        """
        self._t_env = inputs[0]._t_env
        j_tables = _to_java_tables(*inputs)
        self._java_obj.setModelData(j_tables)
        return self

    def get_model_data(self) -> List[Table]:
        """
        Returns the model data of the wrapped Java model as Python tables.

        The tables are created with the table environment stored by
        ``set_model_data``, which is None if that method has not been called.
        """
        j_tables = self._java_obj.getModelData()
        return [Table(j_table, self._t_env) for j_table in j_tables]
class JavaEstimator(Estimator, JavaStage, ABC):
    """
    An Estimator backed by a Java Estimator object.
    """

    def __init__(self):
        # A fresh Java estimator is created from the class named by _java_stage_path.
        super(JavaEstimator, self).__init__(_new_java_obj(self._java_stage_path()))

    def fit(self, *inputs: Table) -> Model:
        """
        Fits the wrapped Java estimator on the given tables.

        :param inputs: The input tables.
        :return: The result of ``_create_model`` for the fitted Java model.
        """
        j_model = self._java_obj.fit(_to_java_tables(*inputs))
        return self._create_model(j_model)

    @classmethod
    def _create_model(cls, java_model) -> Model:
        """
        Creates a model from the input Java model reference.

        This base implementation does nothing and returns None.
        """
        pass

    @classmethod
    def load(cls, t_env: StreamTableEnvironment, path: str):
        """
        Instantiates a new stage instance based on the data read from the given path.

        :param t_env: The table environment; its Java counterpart is handed to Java.
        :param path: The path to load from.
        :return: A new instance of ``cls`` whose Java object is the loaded estimator.
        """
        j_estimator = _to_java_reference(cls._java_stage_path()).load(t_env._j_tenv, path)
        estimator = cls()
        # Swap the Java object created by the constructor for the loaded one.
        estimator._java_obj = j_estimator
        return estimator
class JavaParamConverter(ABC):
    """
    Interface for converting parameter values between Python and Java.
    """

    @abstractmethod
    def to_java(self, value):
        """
        Converts a Python parameter value to its Java representation.
        """
        pass

    @abstractmethod
    def to_python(self, value):
        """
        Converts a Java parameter value to its Python representation.
        """
        pass
class DefaultJavaParamConverter(JavaParamConverter):
    """
    Converter that passes values through unchanged in both directions.
    """

    def to_java(self, value):
        """
        Returns ``value`` as is.
        """
        return value

    def to_python(self, value):
        """
        Returns ``value`` as is.
        """
        return value
class IntArrayJavaPramConverter(JavaParamConverter):
    """
    Converter between a Python sequence and a Java ``Integer`` array.
    """

    def to_java(self, value):
        """
        Builds a Java array with component type ``java.lang.Integer`` from ``value``.
        """
        j_integer_class = get_gateway().jvm.java.lang.Integer
        return to_jarray(j_integer_class, value)

    def to_python(self, value):
        """
        Copies the elements of ``value`` by index into a tuple.
        """
        length = len(value)
        return tuple([value[idx] for idx in range(length)])
class FloatArrayJavaPramConverter(JavaParamConverter):
    """
    Converter between a Python sequence and a Java ``Double`` array.
    """

    def to_java(self, value):
        """
        Builds a Java array with component type ``java.lang.Double`` from ``value``.
        """
        j_double_class = get_gateway().jvm.java.lang.Double
        return to_jarray(j_double_class, value)

    def to_python(self, value):
        """
        Copies the elements of ``value`` by index into a tuple.
        """
        length = len(value)
        return tuple([value[idx] for idx in range(length)])
class VectorJavaParamConverter(JavaParamConverter):
    """
    Converter between a Python vector and a Java Flink ML ``DenseVector``.
    """

    def to_java(self, value):
        """
        Creates a Java ``DenseVector`` from the array returned by ``value.to_array()``.
        """
        j_values = to_jarray(get_gateway().jvm.double, value.to_array())
        j_dense_vector_class = get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector
        return j_dense_vector_class(j_values)

    def to_python(self, value):
        """
        Creates a Python ``DenseVector`` from the elements read via ``value.get(i)``.
        """
        size = value.size()
        elements = tuple([value.get(idx) for idx in range(size)])
        return DenseVector(elements)
class StringArrayJavaParamConverter(JavaParamConverter):
    """
    Converter between a Python sequence and a Java ``String`` array.
    """

    def to_java(self, value):
        """
        Builds a Java array with component type ``java.lang.String`` from ``value``.
        """
        j_string_class = get_gateway().jvm.java.lang.String
        return to_jarray(j_string_class, value)

    def to_python(self, value):
        """
        Copies the elements of ``value`` by index into a tuple.
        """
        length = len(value)
        return tuple([value[idx] for idx in range(length)])
class FloatArrayArrayJavaPramConverter(JavaParamConverter):
    """
    Converter between a Python sequence of float sequences and a two-dimensional
    Java ``Double`` array.
    """

    def to_java(self, value):
        """
        Builds a two-dimensional Java ``Double`` array from ``value``.

        The Java array is rectangular and its column count is taken from the first
        row, so all rows are expected to be as long as the first one. An empty input
        yields an empty Java array.

        :param value: A sequence of equally long sequences of floats.
        :return: The Java array holding the same values.
        """
        gateway = get_gateway()
        n = len(value)
        # Without rows there is no first row to measure; use zero columns.
        m = len(value[0]) if n > 0 else 0
        j_arr = gateway.new_array(gateway.jvm.java.lang.Double, n, m)
        for i, row in enumerate(value):
            # Fetch the Java row once rather than once per cell.
            j_row = j_arr[i]
            for j in range(m):
                j_row[j] = row[j]
        return j_arr

    def to_python(self, value):
        """
        Converts a two-dimensional Java array to a tuple of tuples.

        The length of every row is read from the row itself, so empty and jagged
        arrays are converted as they are.

        :param value: The two-dimensional Java array.
        :return: A tuple with one tuple of elements per row.
        """
        rows = []
        for i in range(len(value)):
            # Fetch the Java row once rather than once per cell.
            j_row = value[i]
            rows.append(tuple(j_row[j] for j in range(len(j_row))))
        return tuple(rows)
# Converter used for params whose type has no entry in the mapping below.
default_converter = DefaultJavaParamConverter()

# Maps a Param class to the converter that translates its values between Python and
# Java. JavaWithParams.set/get look up the exact type of the param in this mapping.
_map_java_param_converter = {
    IntArrayParam: IntArrayJavaPramConverter(),
    FloatArrayParam: FloatArrayJavaPramConverter(),
    FloatArrayArrayParam: FloatArrayArrayJavaPramConverter(),
    StringArrayParam: StringArrayJavaParamConverter(),
    VectorParam: VectorJavaParamConverter(),
    Param: default_converter
}
def snake_to_camel(method_name):
    """
    Converts a snake_case name to lowerCamelCase.

    Each underscore-separated word is capitalized (first letter upper case, the rest
    lower case) and the words are joined; the first character of the result is then
    lowered. An empty word, as produced by a leading, trailing or doubled underscore
    or by an empty input, is replaced with a literal underscore.

    :param method_name: The snake_case name.
    :return: The lowerCamelCase name.
    """
    pieces = []
    for word in method_name.split('_'):
        pieces.append(word.capitalize() if word else '_')
    joined = ''.join(pieces)
    return joined[0].lower() + joined[1:]
def _to_java_reference(java_class: str):
    """
    Resolves a dot-separated Java name to the object it denotes in the gateway's JVM
    view, by looking up each segment as an attribute of the previous one.

    :param java_class: The dot-separated Java name.
    :return: The object reached after the last segment.
    """
    reference = get_gateway().jvm
    for segment in java_class.split("."):
        reference = getattr(reference, segment)
    return reference
def _new_java_obj(java_class: str, *java_args):
    """
    Returns a new Java object.

    :param java_class: The dot-separated name of the Java class to instantiate.
    :param java_args: The arguments passed to the Java constructor.
    :return: The newly created Java object.
    """
    return _to_java_reference(java_class)(*java_args)
def _to_java_tables(*inputs: Table):
    """
    Converts Python Tables to Java tables.

    :param inputs: The Python tables.
    :return: A Java array of ``org.apache.flink.table.api.Table`` holding the Java
        table of every input, in the same order.
    """
    j_table_class = get_gateway().jvm.org.apache.flink.table.api.Table
    j_tables = [table._j_table for table in inputs]
    return to_jarray(j_table_class, j_tables)