blob: 839f493011730456f5c67c07f92101848ca11dd0 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.apex.malhar.python.base.util;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.Serializer;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import jep.NDArray;
/**
 * A handy Kryo serializer class that can be used to serialize and deserialize a JEP NDArray instance. It is
 * recommended that {@link NDimensionalArray} be used in lieu of this class. This is because NDArray is a data
 * structure highly specific to JEP and will not give flexibility if the python engines are changed in the future.
 */
public class NDArrayKryoSerializer extends Serializer<NDArray>
{
  // Booleans travel on the wire as shorts: 1 for true, 0 for false.
  private static final short TRUE_AS_SHORTINT = 1;
  private static final short FALSE_AS_SHORTINT = 0;

  /**
   * Forwards the generic type information to the base serializer unchanged.
   *
   * @param kryo the Kryo instance in use
   * @param generics the generic type arguments supplied by Kryo
   */
  @Override
  public void setGenerics(Kryo kryo, Class[] generics)
  {
    super.setGenerics(kryo, generics);
  }

  /**
   * Serializes an NDArray. The record layout, in order, is:
   * <ol>
   *   <li>int - product of all dimension sizes (the length of the flattened data array)</li>
   *   <li>int - number of dimensions</li>
   *   <li>int[] - the dimension sizes</li>
   *   <li>boolean - the value of {@code ndArray.isUnsigned()}</li>
   *   <li>String - canonical name of the array component type (for example {@code "float"})</li>
   *   <li>the flattened data; boolean data is encoded as shorts (1 = true, 0 = false)</li>
   * </ol>
   * The element type is validated before anything is written, so a rejected array leaves no partial record
   * on the output.
   *
   * @param kryo the Kryo instance in use (not consulted by this serializer)
   * @param output the output the record is written to
   * @param ndArray the array to serialize
   * @throws IllegalArgumentException if the data held by the NDArray is not an array, or its component type is
   *     not one of float, int, double, long, short, byte or boolean
   */
  @Override
  public void write(Kryo kryo, Output output, NDArray ndArray)
  {
    Object data = ndArray.getData();
    if (data == null) {
      // Nothing at all is written in this case, so read() has no matching record to consume.
      // NOTE(review): presumably unreachable because NDArray rejects missing data on construction - confirm.
      return;
    }

    // getComponentType() yields null when the object is not an array.
    Class<?> componentType = data.getClass().getComponentType();
    if (componentType == null) {
      throw new IllegalArgumentException("NDArray data is not an array: " + data.getClass().getName());
    }
    String typeName = componentType.getCanonicalName();
    if (!isSupportedType(typeName)) {
      throw new IllegalArgumentException("Unsupported NDArray type serialization object: " + typeName);
    }

    int[] dimensions = ndArray.getDimensions();
    int flattenedLength = 1;
    for (int dimensionSize : dimensions) {
      flattenedLength *= dimensionSize;
    }

    // Header: flattened length, dimension count, dimensions, unsigned flag, element type name.
    output.writeInt(flattenedLength);
    output.writeInt(dimensions.length);
    output.writeInts(dimensions);
    output.writeBoolean(ndArray.isUnsigned());
    output.writeString(typeName);

    // Payload: the flattened array contents.
    switch (typeName) {
      case "float":
        output.writeFloats((float[])data);
        break;
      case "int":
        output.writeInts((int[])data);
        break;
      case "double":
        output.writeDoubles((double[])data);
        break;
      case "long":
        output.writeLongs((long[])data);
        break;
      case "short":
        output.writeShorts((short[])data);
        break;
      case "byte":
        output.writeBytes((byte[])data);
        break;
      case "boolean":
        output.writeShorts(booleansToShorts((boolean[])data));
        break;
      default:
        // Cannot happen: isSupportedType() accepted the name above.
        throw new IllegalStateException("Unsupported NDArray type serialization object: " + typeName);
    }
  }

  /**
   * Deserializes an NDArray from a record produced by {@link #write(Kryo, Output, NDArray)}.
   *
   * @param kryo the Kryo instance in use (not consulted by this serializer)
   * @param input the input positioned at the start of a record
   * @param aClass the class being deserialized (not consulted by this serializer)
   * @return the reconstructed NDArray, or {@code null} if the record carries no element type name
   * @throws IllegalStateException if the record names an element type this serializer does not support
   */
  @Override
  public NDArray read(Kryo kryo, Input input, Class<NDArray> aClass)
  {
    int flattenedLength = input.readInt();
    int dimensionCount = input.readInt();
    int[] dimensions = input.readInts(dimensionCount);
    boolean unsignedFlag = input.readBoolean();
    String typeName = input.readString();
    if (typeName == null) {
      return null;
    }
    switch (typeName) {
      case "float":
        return new NDArray<float[]>(input.readFloats(flattenedLength), unsignedFlag, dimensions);
      case "int":
        return new NDArray<int[]>(input.readInts(flattenedLength), unsignedFlag, dimensions);
      case "double":
        return new NDArray<double[]>(input.readDoubles(flattenedLength), unsignedFlag, dimensions);
      case "long":
        return new NDArray<long[]>(input.readLongs(flattenedLength), unsignedFlag, dimensions);
      case "short":
        return new NDArray<short[]>(input.readShorts(flattenedLength), unsignedFlag, dimensions);
      case "byte":
        return new NDArray<byte[]>(input.readBytes(flattenedLength), unsignedFlag, dimensions);
      case "boolean":
        return new NDArray<boolean[]>(
          shortsToBooleans(input.readShorts(flattenedLength)), unsignedFlag, dimensions);
      default:
        // The payload cannot be consumed without knowing its type, so fail loudly instead of
        // returning null and leaving the input positioned in the middle of a record.
        throw new IllegalStateException("Unsupported NDArray type in serialized data: " + typeName);
    }
  }

  /** Tells whether the given component type name is one this serializer can write and read back. */
  private static boolean isSupportedType(String typeName)
  {
    switch (typeName) {
      case "float":
      case "int":
      case "double":
      case "long":
      case "short":
      case "byte":
      case "boolean":
        return true;
      default:
        return false;
    }
  }

  /** Encodes a boolean array as shorts, using TRUE_AS_SHORTINT / FALSE_AS_SHORTINT. */
  private static short[] booleansToShorts(boolean[] flags)
  {
    short[] encoded = new short[flags.length];
    for (int i = 0; i < flags.length; i++) {
      encoded[i] = flags[i] ? TRUE_AS_SHORTINT : FALSE_AS_SHORTINT;
    }
    return encoded;
  }

  /** Decodes shorts back into booleans; only TRUE_AS_SHORTINT maps to true. */
  private static boolean[] shortsToBooleans(short[] encoded)
  {
    boolean[] flags = new boolean[encoded.length];
    for (int i = 0; i < encoded.length; i++) {
      flags[i] = encoded[i] == TRUE_AS_SHORTINT;
    }
    return flags;
  }
}