/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.runtime.arrow.serializers;

import org.apache.flink.api.python.shaded.org.apache.arrow.memory.BufferAllocator;
import org.apache.flink.api.python.shaded.org.apache.arrow.memory.RootAllocator;
import org.apache.flink.api.python.shaded.org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.flink.api.python.shaded.org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.flink.api.python.shaded.org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.arrow.ArrowReader;
import org.apache.flink.table.runtime.arrow.ArrowUtils;
import org.apache.flink.table.runtime.arrow.ArrowWriter;
import org.apache.flink.table.types.logical.RowType;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/**
* This code is copied from flink-python; {@link #finishCurrentBatch()} is modified to add an end
* operation.
*
* <p>The ArrowSerializer serializes/deserializes {@link RowType} data to/from Arrow bytes.
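*
* <p>A minimal usage sketch. The {@code inputType}, {@code outputType}, {@code in}, {@code out}
* and {@code row} variables are assumed to be provided by the caller; they are not part of this
* class:
*
* <pre>{@code
* ArrowSerializer serializer = new ArrowSerializer(inputType, outputType);
* serializer.open(in, out);
* serializer.write(row);            // buffer an input element
* serializer.finishCurrentBatch();  // serialize the buffered elements as one Arrow batch
* int rowCount = serializer.load(); // load the next batch of results
* for (int i = 0; i < rowCount; i++) {
*     RowData result = serializer.read(i);
* }
* serializer.close();
* }</pre>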
*/
public final class ArrowSerializer {
static {
ArrowUtils.checkArrowUsable();
}
/** The input RowType. */
protected final RowType inputType;
/** The output RowType. */
protected final RowType outputType;
/** Allocator which is used for byte buffer allocation. */
private transient BufferAllocator allocator;
/** Reader which is responsible for deserializing the Arrow format data into Flink rows. */
private transient ArrowReader arrowReader;
/**
* Reader which is responsible for converting the execution results from byte arrays into Arrow
* format.
*/
private transient ArrowStreamReader arrowStreamReader;
/**
* Container that holds a set of vectors for the input elements to be sent to the Python worker.
*/
transient VectorSchemaRoot rootWriter;
/** Writer which is responsible for serializing the input elements into Arrow format. */
private transient ArrowWriter<RowData> arrowWriter;
/** Writer which is responsible for converting the Arrow format data into byte arrays. */
private transient ArrowStreamWriter arrowStreamWriter;
/** Reusable InputStream used to hold the execution results to be deserialized. */
private transient InputStream bais;
/** Reusable OutputStream used to hold the serialized input elements. */
private transient OutputStream baos;
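/** Creates an ArrowSerializer for the given input and output {@link RowType}s. */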
public ArrowSerializer(RowType inputType, RowType outputType) {
this.inputType = inputType;
this.outputType = outputType;
}
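/**
* Initializes the serializer: creates the allocator, the {@link ArrowStreamReader} over the given
* input stream and the {@link ArrowStreamWriter} over the given output stream, and starts the
* Arrow stream.
*/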
public void open(InputStream bais, OutputStream baos) throws Exception {
this.bais = bais;
this.baos = baos;
allocator = new RootAllocator(Long.MAX_VALUE);
arrowStreamReader = new ArrowStreamReader(bais, allocator);
rootWriter = VectorSchemaRoot.create(ArrowUtils.toArrowSchema(inputType), allocator);
arrowWriter = createArrowWriter();
arrowStreamWriter = new ArrowStreamWriter(rootWriter, null, baos);
arrowStreamWriter.start();
}
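/** Loads the next Arrow batch from the input stream and returns the number of rows it contains. */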
public int load() throws IOException {
arrowStreamReader.loadNextBatch();
VectorSchemaRoot root = arrowStreamReader.getVectorSchemaRoot();
if (arrowReader == null) {
arrowReader = createArrowReader(root);
}
return root.getRowCount();
}
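/** Reads the {@link RowData} at the given index from the currently loaded Arrow batch. */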
public RowData read(int i) {
return arrowReader.read(i);
}
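/**
* Writes the given element into the current batch. The buffered elements are serialized when
* {@link #finishCurrentBatch()} is called.
*/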
public void write(RowData element) {
arrowWriter.write(element);
}
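/** Ends the Arrow stream and releases the stream reader, the root writer and the allocator. */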
public void close() throws Exception {
arrowStreamWriter.end();
arrowStreamReader.close();
rootWriter.close();
allocator.close();
}
/** Creates an {@link ArrowWriter}. */
public ArrowWriter<RowData> createArrowWriter() {
return ArrowUtils.createRowDataArrowWriter(rootWriter, inputType);
}
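/** Creates an {@link ArrowReader} for the given {@link VectorSchemaRoot}. */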
public ArrowReader createArrowReader(VectorSchemaRoot root) {
return ArrowUtils.createArrowReader(root, outputType);
}
/**
* Forces finishing the processing of the current batch of elements. It serializes the buffered
* elements into one Arrow batch and ends the Arrow stream.
*/
public void finishCurrentBatch() throws Exception {
arrowWriter.finish();
arrowStreamWriter.writeBatch();
arrowStreamWriter.end();
arrowWriter.reset();
}
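/**
* Resets the reader: closes the current {@link ArrowStreamReader} and creates a new one over the
* same input stream.
*/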
public void resetReader() throws IOException {
arrowReader = null;
arrowStreamReader.close();
arrowStreamReader = new ArrowStreamReader(bais, allocator);
}
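/**
* Resets the writer: creates a new {@link ArrowStreamWriter} over the same output stream and
* starts a new Arrow stream.
*/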
public void resetWriter() throws IOException {
arrowStreamWriter = new ArrowStreamWriter(rootWriter, null, baos);
arrowStreamWriter.start();
}
}