| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.arrow.vector.ipc; |
| |
| import java.io.IOException; |
| import java.nio.channels.WritableByteChannel; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Optional; |
| import java.util.Set; |
| import org.apache.arrow.vector.FieldVector; |
| import org.apache.arrow.vector.VectorSchemaRoot; |
| import org.apache.arrow.vector.VectorUnloader; |
| import org.apache.arrow.vector.compression.CompressionCodec; |
| import org.apache.arrow.vector.compression.CompressionUtil; |
| import org.apache.arrow.vector.compression.NoCompressionCodec; |
| import org.apache.arrow.vector.dictionary.Dictionary; |
| import org.apache.arrow.vector.dictionary.DictionaryProvider; |
| import org.apache.arrow.vector.ipc.message.ArrowBlock; |
| import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; |
| import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; |
| import org.apache.arrow.vector.ipc.message.IpcOption; |
| import org.apache.arrow.vector.ipc.message.MessageSerializer; |
| import org.apache.arrow.vector.types.pojo.Field; |
| import org.apache.arrow.vector.types.pojo.Schema; |
| import org.apache.arrow.vector.util.DictionaryUtility; |
| import org.apache.arrow.vector.validate.MetadataV4UnionChecker; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** Abstract base class for implementing Arrow writers for IPC over a WriteChannel. */ |
| public abstract class ArrowWriter implements AutoCloseable { |
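
  // Typical write sequence, as a sketch. Assumes a concrete subclass such as ArrowStreamWriter
  // from this package; loadNextBatchInto is a hypothetical caller-side step that refills root.
  //
  //   try (ArrowWriter writer = new ArrowStreamWriter(root, provider, channel)) {
  //     writer.start();
  //     while (loadNextBatchInto(root)) {
  //       writer.writeBatch();
  //     }
  //     writer.end();
  //   }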
| |
| protected static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class); |
| |
| // schema with fields in message format, not memory format |
| protected final Schema schema; |
| protected final WriteChannel out; |
| |
| private final VectorUnloader unloader; |
| private final DictionaryProvider dictionaryProvider; |
| private final Set<Long> dictionaryIdsUsed = new HashSet<>(); |
| |
| private final CompressionCodec.Factory compressionFactory; |
| private final CompressionUtil.CodecType codecType; |
| private final Optional<Integer> compressionLevel; |
| private boolean started = false; |
| private boolean ended = false; |
| |
| private final CompressionCodec codec; |
| |
| protected IpcOption option; |
| |
| protected ArrowWriter( |
| VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { |
| this(root, provider, out, IpcOption.DEFAULT); |
| } |
| |
| protected ArrowWriter( |
| VectorSchemaRoot root, |
| DictionaryProvider provider, |
| WritableByteChannel out, |
| IpcOption option) { |
| this( |
| root, |
| provider, |
| out, |
| option, |
| NoCompressionCodec.Factory.INSTANCE, |
| CompressionUtil.CodecType.NO_COMPRESSION, |
| Optional.empty()); |
| } |
| |
| /** |
| * Note: fields are not closed when the writer is closed. |
| * |
| * @param root the vectors to write to the output |
| * @param provider where to find the dictionaries |
| * @param out the output where to write |
| * @param option IPC write options |
| * @param compressionFactory Compression codec factory |
| * @param codecType Compression codec |
| * @param compressionLevel Compression level |
| */ |
| protected ArrowWriter( |
| VectorSchemaRoot root, |
| DictionaryProvider provider, |
| WritableByteChannel out, |
| IpcOption option, |
| CompressionCodec.Factory compressionFactory, |
| CompressionUtil.CodecType codecType, |
| Optional<Integer> compressionLevel) { |
| this.out = new WriteChannel(out); |
| this.option = option; |
| this.dictionaryProvider = provider; |
| |
| this.compressionFactory = compressionFactory; |
| this.codecType = codecType; |
| this.compressionLevel = compressionLevel; |
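    // Create the codec once; pass an explicit level to the factory only when one was supplied,
    // otherwise let the factory fall back to its default level.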
| this.codec = |
| this.compressionLevel.isPresent() |
| ? this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get()) |
| : this.compressionFactory.createCodec(this.codecType); |
| this.unloader = |
| new VectorUnloader(root, /*includeNullCount*/ true, codec, /*alignBuffers*/ true); |
| |
| List<Field> fields = new ArrayList<>(root.getSchema().getFields().size()); |
| |
| MetadataV4UnionChecker.checkForUnion( |
| root.getSchema().getFields().iterator(), option.metadataVersion); |
| // Convert fields with dictionaries to have dictionary type |
| for (Field field : root.getSchema().getFields()) { |
| fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed)); |
| } |
| |
| this.schema = new Schema(fields, root.getSchema().getCustomMetadata()); |
| } |
| |
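  /**
   * Starts the writer if it has not been started yet: invokes {@link #startInternal(WriteChannel)}
   * and writes the schema message.
   */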
| public void start() throws IOException { |
| ensureStarted(); |
| } |
| |
| /** Writes the record batch currently loaded in this instance's VectorSchemaRoot. */ |
| public void writeBatch() throws IOException { |
| ensureStarted(); |
| ensureDictionariesWritten(dictionaryProvider, dictionaryIdsUsed); |
| try (ArrowRecordBatch batch = unloader.getRecordBatch()) { |
| writeRecordBatch(batch); |
| } |
| } |
| |
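  /**
   * Unloads the dictionary's vector and writes it as a single (non-delta) dictionary batch.
   */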
| protected void writeDictionaryBatch(Dictionary dictionary) throws IOException { |
| FieldVector vector = dictionary.getVector(); |
| long id = dictionary.getEncoding().getId(); |
| int count = vector.getValueCount(); |
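    // Wrap the dictionary vector in a temporary root so it can be unloaded like a record batch.
    // The root is not closed here: closing it would close the dictionary's vector, which this
    // writer does not own.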
| VectorSchemaRoot dictRoot = |
| new VectorSchemaRoot( |
| Collections.singletonList(vector.getField()), Collections.singletonList(vector), count); |
| VectorUnloader unloader = |
| new VectorUnloader(dictRoot, /*includeNullCount*/ true, this.codec, /*alignBuffers*/ true); |
| ArrowRecordBatch batch = unloader.getRecordBatch(); |
    ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, /*isDelta*/ false);
| try { |
| writeDictionaryBatch(dictionaryBatch); |
| } finally { |
| try { |
| dictionaryBatch.close(); |
| } catch (Exception e) { |
        throw new RuntimeException("Error occurred while closing dictionary batch.", e);
| } |
| } |
| } |
| |
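  /**
   * Serializes a dictionary batch to the output channel.
   *
   * @return the block describing the batch's position and size in the output
   */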
| protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException { |
| ArrowBlock block = MessageSerializer.serialize(out, batch, option); |
| if (LOGGER.isDebugEnabled()) { |
| LOGGER.debug( |
| "DictionaryRecordBatch at {}, metadata: {}, body: {}", |
| block.getOffset(), |
| block.getMetadataLength(), |
| block.getBodyLength()); |
| } |
| return block; |
| } |
| |
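  /**
   * Serializes a record batch to the output channel.
   *
   * @return the block describing the batch's position and size in the output
   */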
| protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException { |
| ArrowBlock block = MessageSerializer.serialize(out, batch, option); |
| if (LOGGER.isDebugEnabled()) { |
| LOGGER.debug( |
| "RecordBatch at {}, metadata: {}, body: {}", |
| block.getOffset(), |
| block.getMetadataLength(), |
| block.getBodyLength()); |
| } |
| return block; |
| } |
| |
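  /**
   * Ends the stream if it has not been ended yet, invoking {@link #endInternal(WriteChannel)} so
   * that subclasses can write trailing data such as an end-of-stream marker or a file footer.
   */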
| public void end() throws IOException { |
| ensureStarted(); |
| ensureEnded(); |
| } |
| |
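  /** Returns the number of bytes written to the output channel so far. */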
| public long bytesWritten() { |
| return out.getCurrentPosition(); |
| } |
| |
| private void ensureStarted() throws IOException { |
| if (!started) { |
| started = true; |
| startInternal(out); |
| // write the schema - for file formats this is duplicated in the footer, but matches |
| // the streaming format |
| MessageSerializer.serialize(out, schema, option); |
| } |
| } |
| |
| /** |
| * Write dictionaries after schema and before recordBatches, dictionaries won't be written if |
| * empty stream (only has schema data in IPC). |
| */ |
| protected abstract void ensureDictionariesWritten( |
| DictionaryProvider provider, Set<Long> dictionaryIdsUsed) throws IOException; |
| |
| private void ensureEnded() throws IOException { |
| if (!ended) { |
| ended = true; |
| endInternal(out); |
| } |
| } |
| |
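  /** Hook for subclasses to write format-specific content before the schema; no-op by default. */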
| protected void startInternal(WriteChannel out) throws IOException {} |
| |
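  /** Hook for subclasses to write trailing content after the last batch; no-op by default. */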
| protected void endInternal(WriteChannel out) throws IOException {} |
| |
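  /**
   * Ends the stream if necessary and closes the underlying channel. The vectors and dictionaries
   * supplied by the caller are not closed.
   */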
| @Override |
| public void close() { |
| try { |
| end(); |
| out.close(); |
| } catch (Exception e) { |
| throw new RuntimeException(e); |
| } |
| } |
| } |