diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 1a6f4ac..42d555b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -26,7 +26,7 @@ import org.apache.arrow.flatbuf.MessageHeader
 import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
-import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
+import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}

 import org.apache.spark.TaskContext
 import org.apache.spark.api.java.JavaRDD
@@ -64,7 +64,7 @@ private[sql] class ArrowBatchStreamWriter(
    * End the Arrow stream, does not close output stream.
    */
   def end(): Unit = {
-    ArrowStreamWriter.writeEndOfStream(writeChannel)
+    ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption)
   }
 }
@@ -252,7 +252,7 @@ private[sql] object ArrowConverters {
         if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) {

           // Buffer backed output large enough to hold the complete serialized message
-          val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
+          val bbout = new ByteBufferOutputStream(8 + msgMetadata.getMessageLength + bodyLength)

           // Write message metadata to ByteBuffer output stream
           MessageSerializer.writeMessageBuffer(
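
Background on the two functional changes above (not part of the patch itself): Arrow 0.15 changed the IPC encapsulated message format (ARROW-6313) so that each message is framed by a 4-byte continuation marker (0xFFFFFFFF) followed by the 4-byte metadata length, i.e. 8 bytes of framing instead of the legacy 4. IpcOption selects between the new and the legacy framing (its write_legacy_ipc_format flag defaults to the new format), which is why writeEndOfStream now takes one and why the serialized-message buffer above is sized with 8 rather than 4 extra bytes. Below is a minimal sketch of the same API, assuming Arrow 0.15.x on the classpath; the object name EosSketch is illustrative only.

import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
import org.apache.arrow.vector.ipc.message.IpcOption

// Writes only the end-of-stream marker: with the default (non-legacy)
// IpcOption this is the 0xFFFFFFFF continuation word followed by a zero
// metadata length, 8 bytes in total -- matching the 4 -> 8 sizing change.
object EosSketch {
  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    val channel = new WriteChannel(Channels.newChannel(out))
    ArrowStreamWriter.writeEndOfStream(channel, new IpcOption)
    // Expected output: ff ff ff ff 00 00 00 00
    println(out.toByteArray.map(b => f"${b & 0xff}%02x").mkString(" "))
  }
}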