[FLINK-33118] Remove PythonBridgeUtils and the convert_to_python_obj override in wrapper.py
This closes #256.
diff --git a/flink-ml-dist/pom.xml b/flink-ml-dist/pom.xml
index 7c260ec..1bb239d 100644
--- a/flink-ml-dist/pom.xml
+++ b/flink-ml-dist/pom.xml
@@ -46,12 +46,6 @@
<version>${project.version}</version>
</dependency>
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-ml-python-${flink.main.version}</artifactId>
- <version>${project.version}</version>
- </dependency>
-
<!-- Stateful Functions Dependencies -->
<dependency>
diff --git a/flink-ml-dist/src/main/assemblies/bin.xml b/flink-ml-dist/src/main/assemblies/bin.xml
index 00ad9dc..d260e0d 100644
--- a/flink-ml-dist/src/main/assemblies/bin.xml
+++ b/flink-ml-dist/src/main/assemblies/bin.xml
@@ -37,7 +37,6 @@
<include>org.apache.flink:statefun-flink-core</include>
<include>org.apache.flink:flink-ml-uber-${flink.main.version}</include>
<include>org.apache.flink:flink-ml-examples-${flink.main.version}</include>
- <include>org.apache.flink:flink-ml-python-${flink.main.version}</include>
</includes>
</dependencySet>
</dependencySets>
diff --git a/flink-ml-python/pom.xml b/flink-ml-python/pom.xml
index 34e7e5a..c481b21 100644
--- a/flink-ml-python/pom.xml
+++ b/flink-ml-python/pom.xml
@@ -32,20 +32,6 @@
<dependencies>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-table-common</artifactId>
- <version>${flink.version}</version>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>${flink.python.artifact}</artifactId>
- <version>${flink.version}</version>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
<artifactId>flink-runtime</artifactId>
<version>${flink.version}</version>
<type>test-jar</type>
diff --git a/flink-ml-python/pyflink/ml/feature/tests/test_indextostringmodel.py b/flink-ml-python/pyflink/ml/feature/tests/test_indextostringmodel.py
index cae0673..34779b6 100644
--- a/flink-ml-python/pyflink/ml/feature/tests/test_indextostringmodel.py
+++ b/flink-ml-python/pyflink/ml/feature/tests/test_indextostringmodel.py
@@ -45,8 +45,8 @@
))
self.expected_prediction = [
- Row(0, 3, 'a', '2.0'),
- Row(1, 2, 'b', '1.0'),
+ Row(input_col1=0, input_col2=3, output_col1='a', output_col2='2.0'),
+ Row(input_col1=1, input_col2=2, output_col1='b', output_col2='1.0'),
]
def test_output_schema(self):
diff --git a/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py b/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py
index 9fed97c..a345c81 100644
--- a/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py
+++ b/flink-ml-python/pyflink/ml/feature/tests/test_minhashlsh.py
@@ -179,8 +179,8 @@
.set_num_hash_tables(5) \
.set_num_hash_functions_per_table(1)
expected = [
- Row(0, 0.75),
- Row(1, 0.75),
+ Row(id=0, distCol=0.75),
+ Row(id=1, distCol=0.75),
]
model: MinHashLSHModel = lsh.fit(self.data)
@@ -198,8 +198,8 @@
.set_num_hash_tables(5) \
.set_num_hash_functions_per_table(1)
expected = [
- Row(0, 0.75),
- Row(1, 0.75),
+ Row(id=0, distCol=0.75),
+ Row(id=1, distCol=0.75),
]
model: MinHashLSHModel = lsh.fit(self.data)
@@ -230,6 +230,8 @@
Row(1, 5, .5),
Row(2, 5, .5)
]
+ for r in expected:
+ r.set_field_names(['datasetA.id', 'datasetB.id', 'distCol'])
output = model.approx_similarity_join(data_a, data_b, .6, "id")
actual_result = [r for r in self.t_env.to_data_stream(output).execute_and_collect()]
diff --git a/flink-ml-python/pyflink/ml/feature/tests/test_sqltransformer.py b/flink-ml-python/pyflink/ml/feature/tests/test_sqltransformer.py
index 85d05a4..52ff82b 100644
--- a/flink-ml-python/pyflink/ml/feature/tests/test_sqltransformer.py
+++ b/flink-ml-python/pyflink/ml/feature/tests/test_sqltransformer.py
@@ -34,8 +34,8 @@
['id', 'v1', 'v2'],
[Types.INT(), Types.DOUBLE(), Types.DOUBLE()])))
self.expected_output = [
- (0, 1.0, 3.0, 4.0, 3.0),
- (2, 2.0, 5.0, 7.0, 10.0)
+ Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0),
+ Row(id=2, v1=2.0, v2=5.0, v3=7.0, v4=10.0)
]
def test_param(self):
@@ -62,4 +62,4 @@
actual_output.sort(key=lambda x: x[0])
self.assertEqual(len(self.expected_output), len(actual_output))
for i in range(len(actual_output)):
- self.assertEqual(Row(*self.expected_output[i]), actual_output[i])
+ self.assertEqual(self.expected_output[i], actual_output[i])
diff --git a/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py b/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
index aa61503..0809533 100644
--- a/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
+++ b/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
@@ -58,11 +58,11 @@
[Types.STRING(), Types.DOUBLE()])))
self.expected_alphabetic_asc_predict_data = [
- Row('a', 2.0, 0, 3),
- Row('b', 1.0, 1, 2),
- Row('e', 2.0, 4, 3),
- Row('f', None, 4, 4),
- Row(None, None, 4, 4),
+ Row(input_col1='a', input_col2=2.0, output_col1=0, output_col2=3),
+ Row(input_col1='b', input_col2=1.0, output_col1=1, output_col2=2),
+ Row(input_col1='e', input_col2=2.0, output_col1=4, output_col2=3),
+ Row(input_col1='f', input_col2=None, output_col1=4, output_col2=4),
+ Row(input_col1=None, input_col2=None, output_col1=4, output_col2=4),
]
def test_param(self):
@@ -122,11 +122,11 @@
.set_string_order_type("frequencyDesc")
expected_predict_data = [
- Row('a', 2.0, 1, 0),
- Row('b', 1.0, 0, 2),
- Row('e', 2.0, 2, 0),
- Row('f', None, 2, 2),
- Row(None, None, 2, 2),
+ Row(input_col1='a', input_col2=2.0, output_col1=1, output_col2=0),
+ Row(input_col1='b', input_col2=1.0, output_col1=0, output_col2=2),
+ Row(input_col1='e', input_col2=2.0, output_col1=2, output_col2=0),
+ Row(input_col1='f', input_col2=None, output_col1=2, output_col2=2),
+ Row(input_col1=None, input_col2=None, output_col1=2, output_col2=2),
]
output = string_indexer.fit(self.train_table).transform(self.predict_table)[0]
diff --git a/flink-ml-python/pyflink/ml/stats/tests/test_chisqtest.py b/flink-ml-python/pyflink/ml/stats/tests/test_chisqtest.py
index 290e468..1c4edc5 100644
--- a/flink-ml-python/pyflink/ml/stats/tests/test_chisqtest.py
+++ b/flink-ml-python/pyflink/ml/stats/tests/test_chisqtest.py
@@ -50,8 +50,8 @@
)
self.expected_output_data = [
- Row(0, 0.03419350755, 6, 13.61904761905),
- Row(1, 0.24220177737, 6, 7.94444444444)]
+ Row(featureIndex=0, pValue=0.03419350755, degreeOfFreedom=6, statistic=13.61904761905),
+ Row(featureIndex=1, pValue=0.24220177737, degreeOfFreedom=6, statistic=7.94444444444)]
def test_param(self):
chi_sq_test = ChiSqTest()
diff --git a/flink-ml-python/pyflink/ml/wrapper.py b/flink-ml-python/pyflink/ml/wrapper.py
index f30f218..50f0105 100644
--- a/flink-ml-python/pyflink/ml/wrapper.py
+++ b/flink-ml-python/pyflink/ml/wrapper.py
@@ -15,16 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-import pickle
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from py4j.java_gateway import JavaObject, get_java_class
-from pyflink.common import typeinfo, Time, Row, RowKind
-from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of, Types, \
- ExternalTypeInfo, RowTypeInfo, TupleTypeInfo
-from pyflink.datastream import utils
-from pyflink.datastream.utils import pickled_bytes_to_python_obj
+from pyflink.common import typeinfo, Time
+from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of
from pyflink.java_gateway import get_gateway
from pyflink.table import Table, StreamTableEnvironment, Expression
from pyflink.util.java_utils import to_jarray
@@ -60,36 +56,6 @@
typeinfo._from_java_type = _from_java_type_wrapper
-# TODO: Remove this class after Flink ML depends on a Flink version
-# with FLINK-30168 and FLINK-29477 fixed.
-def convert_to_python_obj_wrapper(data, type_info):
- if type_info == Types.PICKLED_BYTE_ARRAY():
- return pickle.loads(data)
- elif isinstance(type_info, ExternalTypeInfo):
- return convert_to_python_obj_wrapper(data, type_info._type_info)
- else:
- gateway = get_gateway()
- pickle_bytes = gateway.jvm.org.apache.flink.ml.python.PythonBridgeUtils. \
- getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
- if isinstance(type_info, RowTypeInfo) or isinstance(type_info, TupleTypeInfo):
- field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())
- fields = []
- for data, field_type in field_data:
- if len(data) == 0:
- fields.append(None)
- else:
- fields.append(pickled_bytes_to_python_obj(data, field_type))
- if isinstance(type_info, RowTypeInfo):
- return Row.of_kind(RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields)
- else:
- return tuple(fields)
- else:
- return pickled_bytes_to_python_obj(pickle_bytes, type_info)
-
-
-utils.convert_to_python_obj = convert_to_python_obj_wrapper
-
-
class JavaWrapper(ABC):
"""
Wrapper class for a Java object.
diff --git a/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java b/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java
deleted file mode 100644
index 36eca36..0000000
--- a/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java
+++ /dev/null
@@ -1,226 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.python;
-
-import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
-import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.java.tuple.Tuple;
-import org.apache.flink.api.java.typeutils.ListTypeInfo;
-import org.apache.flink.api.java.typeutils.MapTypeInfo;
-import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
-import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.api.java.typeutils.TupleTypeInfo;
-import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
-import org.apache.flink.api.python.shaded.net.razorvine.pickle.Pickler;
-import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo;
-import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
-
-import java.io.IOException;
-import java.sql.Date;
-import java.sql.Time;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
-import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.FLOAT_TYPE_INFO;
-import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.DATE;
-import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.TIME;
-
-/**
- * Utility functions used to override PyFlink methods to provide a temporary solution for certain
- * bugs.
- */
-// TODO: Remove this class after Flink ML depends on a Flink version with FLINK-30168 and
-// FLINK-29477 fixed.
-public class PythonBridgeUtils {
- public static Object getPickledBytesFromJavaObject(Object obj, TypeInformation<?> dataType)
- throws IOException {
- Pickler pickler = new Pickler();
-
- // triggers the initialization process
- org.apache.flink.api.common.python.PythonBridgeUtils.getPickledBytesFromJavaObject(
- null, null);
-
- if (obj == null) {
- return pickler.dumps(null);
- } else {
- if (dataType instanceof SqlTimeTypeInfo) {
- SqlTimeTypeInfo<?> sqlTimeTypeInfo =
- SqlTimeTypeInfo.getInfoFor(dataType.getTypeClass());
- if (sqlTimeTypeInfo == DATE) {
- return pickler.dumps(((Date) obj).toLocalDate().toEpochDay());
- } else if (sqlTimeTypeInfo == TIME) {
- return pickler.dumps(((Time) obj).toLocalTime().toNanoOfDay() / 1000);
- }
- } else if (dataType instanceof RowTypeInfo || dataType instanceof TupleTypeInfo) {
- TypeInformation<?>[] fieldTypes = ((TupleTypeInfoBase<?>) dataType).getFieldTypes();
- int arity =
- dataType instanceof RowTypeInfo
- ? ((Row) obj).getArity()
- : ((Tuple) obj).getArity();
-
- List<Object> fieldBytes = new ArrayList<>(arity + 1);
- if (dataType instanceof RowTypeInfo) {
- fieldBytes.add(new byte[] {((Row) obj).getKind().toByteValue()});
- }
- for (int i = 0; i < arity; i++) {
- Object field =
- dataType instanceof RowTypeInfo
- ? ((Row) obj).getField(i)
- : ((Tuple) obj).getField(i);
- fieldBytes.add(getPickledBytesFromJavaObject(field, fieldTypes[i]));
- }
- return fieldBytes;
- } else if (dataType instanceof BasicArrayTypeInfo
- || dataType instanceof PrimitiveArrayTypeInfo
- || dataType instanceof ObjectArrayTypeInfo) {
- Object[] objects;
- TypeInformation<?> elementType;
- if (dataType instanceof BasicArrayTypeInfo) {
- objects = (Object[]) obj;
- elementType = ((BasicArrayTypeInfo<?, ?>) dataType).getComponentInfo();
- } else if (dataType instanceof PrimitiveArrayTypeInfo) {
- objects = primitiveArrayConverter(obj, dataType);
- elementType = ((PrimitiveArrayTypeInfo<?>) dataType).getComponentType();
- } else {
- objects = (Object[]) obj;
- elementType = ((ObjectArrayTypeInfo<?, ?>) dataType).getComponentInfo();
- }
-
- List<Object> serializedElements = new ArrayList<>(objects.length);
-
- for (Object object : objects) {
- serializedElements.add(getPickledBytesFromJavaObject(object, elementType));
- }
- return pickler.dumps(serializedElements);
- } else if (dataType instanceof MapTypeInfo) {
- List<List<Object>> serializedMapKV = new ArrayList<>(2);
- Map<Object, Object> mapObj = (Map) obj;
- List<Object> keyBytesList = new ArrayList<>(mapObj.size());
- List<Object> valueBytesList = new ArrayList<>(mapObj.size());
- for (Map.Entry entry : mapObj.entrySet()) {
- keyBytesList.add(
- getPickledBytesFromJavaObject(
- entry.getKey(), ((MapTypeInfo) dataType).getKeyTypeInfo()));
- valueBytesList.add(
- getPickledBytesFromJavaObject(
- entry.getValue(), ((MapTypeInfo) dataType).getValueTypeInfo()));
- }
- serializedMapKV.add(keyBytesList);
- serializedMapKV.add(valueBytesList);
- return pickler.dumps(serializedMapKV);
- } else if (dataType instanceof ListTypeInfo) {
- List objects = (List) obj;
- List<Object> serializedElements = new ArrayList<>(objects.size());
- TypeInformation elementType = ((ListTypeInfo) dataType).getElementTypeInfo();
- for (Object object : objects) {
- serializedElements.add(getPickledBytesFromJavaObject(object, elementType));
- }
- return pickler.dumps(serializedElements);
- }
- if (dataType instanceof BasicTypeInfo
- && BasicTypeInfo.getInfoFor(dataType.getTypeClass()) == FLOAT_TYPE_INFO) {
- // Serialization of float type with pickler loses precision.
- return pickler.dumps(String.valueOf(obj));
- } else if (dataType instanceof PickledByteArrayTypeInfo
- || dataType instanceof BasicTypeInfo) {
- return pickler.dumps(obj);
- } else {
- // other typeinfos will use the corresponding serializer to serialize data.
- TypeSerializer serializer = dataType.createSerializer(null);
- ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
- DataOutputViewStreamWrapper baosWrapper = new DataOutputViewStreamWrapper(baos);
- serializer.serialize(obj, baosWrapper);
- return pickler.dumps(baos.toByteArray());
- }
- }
- }
-
- private static Object[] primitiveArrayConverter(
- Object array, TypeInformation<?> arrayTypeInfo) {
- Preconditions.checkArgument(arrayTypeInfo instanceof PrimitiveArrayTypeInfo);
- Preconditions.checkArgument(array.getClass().isArray());
- Object[] objects;
- if (PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- boolean[] booleans = (boolean[]) array;
- objects = new Object[booleans.length];
- for (int i = 0; i < booleans.length; i++) {
- objects[i] = booleans[i];
- }
- } else if (PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- byte[] bytes = (byte[]) array;
- objects = new Object[bytes.length];
- for (int i = 0; i < bytes.length; i++) {
- objects[i] = bytes[i];
- }
- } else if (PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- short[] shorts = (short[]) array;
- objects = new Object[shorts.length];
- for (int i = 0; i < shorts.length; i++) {
- objects[i] = shorts[i];
- }
- } else if (PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- int[] ints = (int[]) array;
- objects = new Object[ints.length];
- for (int i = 0; i < ints.length; i++) {
- objects[i] = ints[i];
- }
- } else if (PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- long[] longs = (long[]) array;
- objects = new Object[longs.length];
- for (int i = 0; i < longs.length; i++) {
- objects[i] = longs[i];
- }
- } else if (PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- float[] floats = (float[]) array;
- objects = new Object[floats.length];
- for (int i = 0; i < floats.length; i++) {
- objects[i] = floats[i];
- }
- } else if (PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- double[] doubles = (double[]) array;
- objects = new Object[doubles.length];
- for (int i = 0; i < doubles.length; i++) {
- objects[i] = doubles[i];
- }
- } else if (PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
- char[] chars = (char[]) array;
- objects = new Object[chars.length];
- for (int i = 0; i < chars.length; i++) {
- objects[i] = chars[i];
- }
- } else {
- throw new UnsupportedOperationException(
- String.format(
- "Primitive array of %s is not supported in PyFlink yet",
- ((PrimitiveArrayTypeInfo<?>) arrayTypeInfo)
- .getComponentType()
- .getTypeClass()
- .getSimpleName()));
- }
- return objects;
- }
-}