blob: 36eca369a4c3c8f9601556092c46692006f13785 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.ml.python;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.python.shaded.net.razorvine.pickle.Pickler;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
import java.sql.Date;
import java.sql.Time;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.FLOAT_TYPE_INFO;
import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.DATE;
import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.TIME;
/**
* Utility functions used to override PyFlink methods to provide a temporary solution for certain
* bugs.
*/
// TODO: Remove this class after Flink ML depends on a Flink version with FLINK-30168 and
// FLINK-29477 fixed.
public class PythonBridgeUtils {

    /**
     * Converts a Java object into the pickled form handed over to PyFlink, driven by the given
     * type information.
     *
     * <p>Conversion rules, in the order they are checked:
     *
     * <ul>
     *   <li>{@code null} objects are pickled as {@code null}.
     *   <li>SQL {@code DATE} values are pickled as their epoch day, SQL {@code TIME} values as
     *       microseconds of the day. Any other SQL time type is written with its type serializer.
     *   <li>Rows and tuples are returned as a {@link List} of converted fields that is
     *       <em>not</em> pickled as a whole; for rows the list starts with a one-element byte
     *       array holding the row kind.
     *   <li>Arrays and lists are pickled as a list of converted elements.
     *   <li>Maps are pickled as a two-element list: converted keys, then converted values.
     *   <li>Floats are pickled as their string representation, because pickling a float directly
     *       loses precision.
     *   <li>Other basic types and already pickled byte arrays are pickled as they are.
     *   <li>Everything else is written with the serializer created from {@code dataType} and the
     *       resulting bytes are pickled.
     * </ul>
     *
     * @param obj the Java object to convert, may be {@code null}
     * @param dataType the type information describing {@code obj}
     * @return the list of converted fields for rows and tuples, otherwise the value returned by
     *     {@code Pickler#dumps}
     * @throws IOException if pickling or serializing the object fails
     */
    public static Object getPickledBytesFromJavaObject(Object obj, TypeInformation<?> dataType)
            throws IOException {
        // Triggers the initialization process of the upstream PyFlink bridge. This is done once
        // per top-level call; the nested conversion below goes through a private helper so that
        // the trigger is not repeated for every field/element of a container.
        // NOTE(review): relies on the upstream initialization being a one-time setup - confirm.
        org.apache.flink.api.common.python.PythonBridgeUtils.getPickledBytesFromJavaObject(
                null, null);
        return pickle(obj, dataType);
    }

    /** Recursive worker behind {@link #getPickledBytesFromJavaObject(Object, TypeInformation)}. */
    private static Object pickle(Object obj, TypeInformation<?> dataType) throws IOException {
        Pickler pickler = new Pickler();
        if (obj == null) {
            return pickler.dumps(null);
        }

        if (dataType instanceof SqlTimeTypeInfo) {
            SqlTimeTypeInfo<?> sqlTimeTypeInfo =
                    SqlTimeTypeInfo.getInfoFor(dataType.getTypeClass());
            if (sqlTimeTypeInfo == DATE) {
                return pickler.dumps(((Date) obj).toLocalDate().toEpochDay());
            } else if (sqlTimeTypeInfo == TIME) {
                // nanoseconds -> microseconds of the day
                return pickler.dumps(((Time) obj).toLocalTime().toNanoOfDay() / 1000);
            }
            // Any other SQL time type falls through to the serializer-based path below.
        } else if (dataType instanceof RowTypeInfo || dataType instanceof TupleTypeInfo) {
            // Rows and tuples are returned field by field, not pickled as a whole.
            return pickleFields(obj, dataType);
        } else if (dataType instanceof BasicArrayTypeInfo
                || dataType instanceof PrimitiveArrayTypeInfo
                || dataType instanceof ObjectArrayTypeInfo) {
            return pickler.dumps(pickleArrayElements(obj, dataType));
        } else if (dataType instanceof MapTypeInfo) {
            return pickler.dumps(
                    pickleMapEntries((Map<?, ?>) obj, (MapTypeInfo<?, ?>) dataType));
        } else if (dataType instanceof ListTypeInfo) {
            return pickler.dumps(pickleListElements((List<?>) obj, (ListTypeInfo<?>) dataType));
        }

        if (dataType instanceof BasicTypeInfo
                && BasicTypeInfo.getInfoFor(dataType.getTypeClass()) == FLOAT_TYPE_INFO) {
            // Serialization of float type with pickler loses precision.
            return pickler.dumps(String.valueOf(obj));
        } else if (dataType instanceof PickledByteArrayTypeInfo
                || dataType instanceof BasicTypeInfo) {
            return pickler.dumps(obj);
        } else {
            // other typeinfos will use the corresponding serializer to serialize data.
            // The cast is safe as far as this method is concerned: the serializer is created from
            // the very type information the caller supplied for obj.
            @SuppressWarnings("unchecked")
            TypeSerializer<Object> serializer =
                    (TypeSerializer<Object>) dataType.createSerializer(null);
            ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
            DataOutputViewStreamWrapper baosWrapper = new DataOutputViewStreamWrapper(baos);
            serializer.serialize(obj, baosWrapper);
            return pickler.dumps(baos.toByteArray());
        }
    }

    /**
     * Converts every field of a row or tuple. For rows the result is prefixed with a one-element
     * byte array containing the byte value of the row kind.
     */
    private static List<Object> pickleFields(Object obj, TypeInformation<?> dataType)
            throws IOException {
        TypeInformation<?>[] fieldTypes = ((TupleTypeInfoBase<?>) dataType).getFieldTypes();
        boolean isRow = dataType instanceof RowTypeInfo;
        int arity = isRow ? ((Row) obj).getArity() : ((Tuple) obj).getArity();

        // One extra slot for the row kind marker.
        List<Object> fieldBytes = new ArrayList<>(arity + 1);
        if (isRow) {
            fieldBytes.add(new byte[] {((Row) obj).getKind().toByteValue()});
        }
        for (int i = 0; i < arity; i++) {
            Object field = isRow ? ((Row) obj).getField(i) : ((Tuple) obj).getField(i);
            fieldBytes.add(pickle(field, fieldTypes[i]));
        }
        return fieldBytes;
    }

    /** Converts every element of a basic, primitive or object array. */
    private static List<Object> pickleArrayElements(Object obj, TypeInformation<?> dataType)
            throws IOException {
        Object[] objects;
        TypeInformation<?> elementType;
        if (dataType instanceof BasicArrayTypeInfo) {
            objects = (Object[]) obj;
            elementType = ((BasicArrayTypeInfo<?, ?>) dataType).getComponentInfo();
        } else if (dataType instanceof PrimitiveArrayTypeInfo) {
            // Primitive arrays cannot be cast to Object[]; box the elements first.
            objects = primitiveArrayConverter(obj, dataType);
            elementType = ((PrimitiveArrayTypeInfo<?>) dataType).getComponentType();
        } else {
            objects = (Object[]) obj;
            elementType = ((ObjectArrayTypeInfo<?, ?>) dataType).getComponentInfo();
        }

        List<Object> serializedElements = new ArrayList<>(objects.length);
        for (Object object : objects) {
            serializedElements.add(pickle(object, elementType));
        }
        return serializedElements;
    }

    /**
     * Converts the entries of a map into a two-element list: the converted keys followed by the
     * converted values, both in the iteration order of the map's entry set.
     */
    private static List<List<Object>> pickleMapEntries(Map<?, ?> mapObj, MapTypeInfo<?, ?> mapType)
            throws IOException {
        TypeInformation<?> keyType = mapType.getKeyTypeInfo();
        TypeInformation<?> valueType = mapType.getValueTypeInfo();

        List<Object> keyBytesList = new ArrayList<>(mapObj.size());
        List<Object> valueBytesList = new ArrayList<>(mapObj.size());
        for (Map.Entry<?, ?> entry : mapObj.entrySet()) {
            keyBytesList.add(pickle(entry.getKey(), keyType));
            valueBytesList.add(pickle(entry.getValue(), valueType));
        }

        List<List<Object>> serializedMapKV = new ArrayList<>(2);
        serializedMapKV.add(keyBytesList);
        serializedMapKV.add(valueBytesList);
        return serializedMapKV;
    }

    /** Converts every element of a list. */
    private static List<Object> pickleListElements(List<?> objects, ListTypeInfo<?> listType)
            throws IOException {
        TypeInformation<?> elementType = listType.getElementTypeInfo();
        List<Object> serializedElements = new ArrayList<>(objects.size());
        for (Object object : objects) {
            serializedElements.add(pickle(object, elementType));
        }
        return serializedElements;
    }

    /**
     * Boxes the elements of a primitive array into an {@code Object[]} of the same length.
     *
     * @param array the primitive array to convert
     * @param arrayTypeInfo the {@link PrimitiveArrayTypeInfo} describing {@code array}
     * @return a new array holding the boxed elements in their original order
     * @throws IllegalArgumentException if {@code arrayTypeInfo} is not a primitive array type
     *     information or {@code array} is not an array
     * @throws UnsupportedOperationException if the primitive component type is not supported
     */
    private static Object[] primitiveArrayConverter(
            Object array, TypeInformation<?> arrayTypeInfo) {
        Preconditions.checkArgument(arrayTypeInfo instanceof PrimitiveArrayTypeInfo);
        Preconditions.checkArgument(array.getClass().isArray());

        if (PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            boolean[] booleans = (boolean[]) array;
            Object[] boxed = new Object[booleans.length];
            for (int i = 0; i < booleans.length; i++) {
                boxed[i] = booleans[i];
            }
            return boxed;
        }
        if (PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            byte[] bytes = (byte[]) array;
            Object[] boxed = new Object[bytes.length];
            for (int i = 0; i < bytes.length; i++) {
                boxed[i] = bytes[i];
            }
            return boxed;
        }
        if (PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            short[] shorts = (short[]) array;
            Object[] boxed = new Object[shorts.length];
            for (int i = 0; i < shorts.length; i++) {
                boxed[i] = shorts[i];
            }
            return boxed;
        }
        if (PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            int[] ints = (int[]) array;
            Object[] boxed = new Object[ints.length];
            for (int i = 0; i < ints.length; i++) {
                boxed[i] = ints[i];
            }
            return boxed;
        }
        if (PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            long[] longs = (long[]) array;
            Object[] boxed = new Object[longs.length];
            for (int i = 0; i < longs.length; i++) {
                boxed[i] = longs[i];
            }
            return boxed;
        }
        if (PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            float[] floats = (float[]) array;
            Object[] boxed = new Object[floats.length];
            for (int i = 0; i < floats.length; i++) {
                boxed[i] = floats[i];
            }
            return boxed;
        }
        if (PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            double[] doubles = (double[]) array;
            Object[] boxed = new Object[doubles.length];
            for (int i = 0; i < doubles.length; i++) {
                boxed[i] = doubles[i];
            }
            return boxed;
        }
        if (PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
            char[] chars = (char[]) array;
            Object[] boxed = new Object[chars.length];
            for (int i = 0; i < chars.length; i++) {
                boxed[i] = chars[i];
            }
            return boxed;
        }

        throw new UnsupportedOperationException(
                String.format(
                        "Primitive array of %s is not supported in PyFlink yet",
                        ((PrimitiveArrayTypeInfo<?>) arrayTypeInfo)
                                .getComponentType()
                                .getTypeClass()
                                .getSimpleName()));
    }
}