| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.orc.nohive; |
| |
| import org.apache.flink.api.common.serialization.BulkWriter; |
| import org.apache.flink.core.fs.FSDataOutputStream; |
| import org.apache.flink.orc.nohive.writer.NoHivePhysicalWriterImpl; |
| import org.apache.flink.table.data.RowData; |
| import org.apache.flink.table.types.logical.DecimalType; |
| import org.apache.flink.table.types.logical.LocalZonedTimestampType; |
| import org.apache.flink.table.types.logical.LogicalType; |
| import org.apache.flink.table.types.logical.TimestampType; |
| |
| import org.apache.hadoop.conf.Configuration; |
| import org.apache.hadoop.fs.Path; |
| import org.apache.orc.OrcFile; |
| import org.apache.orc.TypeDescription; |
| import org.apache.orc.impl.WriterImpl; |
| import org.apache.orc.storage.common.type.HiveDecimal; |
| import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; |
| import org.apache.orc.storage.ql.exec.vector.ColumnVector; |
| import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; |
| import org.apache.orc.storage.ql.exec.vector.DoubleColumnVector; |
| import org.apache.orc.storage.ql.exec.vector.LongColumnVector; |
| import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; |
| import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; |
| |
| import java.io.IOException; |
| import java.io.ObjectInputStream; |
| import java.io.ObjectOutputStream; |
| import java.sql.Timestamp; |
| import java.util.Properties; |
| |
/**
 * A {@link BulkWriter.Factory} for the "no-hive" ORC distribution, which ships Hive's
 * vectorized storage classes relocated under {@code org.apache.orc.storage}.
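 *
 * <p>A minimal usage sketch; the schema string and field types below are illustrative:
 *
 * <pre>{@code
 * LogicalType[] fieldTypes = {new IntType(), new VarCharType(10)};
 * OrcNoHiveBulkWriterFactory factory =
 *         new OrcNoHiveBulkWriterFactory(
 *                 new Configuration(), "struct<id:int,name:string>", fieldTypes);
 * BulkWriter<RowData> writer = factory.create(stream); // stream is an FSDataOutputStream
 * writer.addElement(GenericRowData.of(1, StringData.fromString("flink")));
 * writer.finish();
 * }</pre>
 */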
| public class OrcNoHiveBulkWriterFactory implements BulkWriter.Factory<RowData> { |
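
    // BulkWriter.Factory extends Serializable; declare a stable serialVersionUID.
    private static final long serialVersionUID = 1L;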
| |
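    // Mutable (non-final) because the custom readObject below reassigns them; Hadoop's
    // Configuration is Writable rather than java.io.Serializable.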
| private Configuration conf; |
| private String schema; |
| private LogicalType[] fieldTypes; |
| |
| public OrcNoHiveBulkWriterFactory(Configuration conf, String schema, LogicalType[] fieldTypes) { |
| this.conf = conf; |
| this.schema = schema; |
| this.fieldTypes = fieldTypes; |
| } |
| |
| @Override |
| public BulkWriter<RowData> create(FSDataOutputStream out) throws IOException { |
| OrcFile.WriterOptions opts = OrcFile.writerOptions(new Properties(), conf); |
| TypeDescription description = TypeDescription.fromString(schema); |
| opts.setSchema(description); |
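        // Route all physical output through Flink's stream rather than a Hadoop FileSystem.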
| opts.physicalWriter(new NoHivePhysicalWriterImpl(out, opts)); |
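        // With a PhysicalWriter configured, the FileSystem may be null and the path is a
        // placeholder that ORC only uses as an internal identifier (e.g. for memory management).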
| WriterImpl writer = new WriterImpl(null, new Path("."), opts); |
| |
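        // Rows are buffered into a vectorized batch (1024 rows by default) and handed to the
        // ORC writer whenever the batch fills up.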
| VectorizedRowBatch rowBatch = description.createRowBatch(); |
| return new BulkWriter<RowData>() { |
| @Override |
| public void addElement(RowData row) throws IOException { |
| int rowId = rowBatch.size++; |
| for (int i = 0; i < row.getArity(); ++i) { |
| setColumn(rowId, rowBatch.cols[i], fieldTypes[i], row, i); |
| } |
| if (rowBatch.size == rowBatch.getMaxSize()) { |
| writer.addRowBatch(rowBatch); |
| rowBatch.reset(); |
| } |
| } |
| |
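            // Drain the partially filled batch into the ORC writer.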
| @Override |
| public void flush() throws IOException { |
| if (rowBatch.size != 0) { |
| writer.addRowBatch(rowBatch); |
| rowBatch.reset(); |
| } |
| } |
| |
| @Override |
| public void finish() throws IOException { |
| flush(); |
| writer.close(); |
| } |
| }; |
| } |
| |
    // Custom serialization: Hadoop's Configuration implements Writable, not Serializable,
    // so write it with its own write()/readFields() protocol.
| private void writeObject(ObjectOutputStream out) throws IOException { |
| conf.write(out); |
| out.writeObject(schema); |
| out.writeObject(fieldTypes); |
| } |
| |
| private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { |
| conf = new Configuration(false); |
| conf.readFields(in); |
| schema = (String) in.readObject(); |
| fieldTypes = (LogicalType[]) in.readObject(); |
| } |
| |
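    /**
     * Writes the {@code columnId}-th field of {@code row} into the given ORC column vector at
     * position {@code rowId}.
     */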
| private static void setColumn( |
| int rowId, ColumnVector column, LogicalType type, RowData row, int columnId) { |
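        // ORC vectors track nulls explicitly: clear the noNulls fast path and mark this slot.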
| if (row.isNullAt(columnId)) { |
| column.noNulls = false; |
| column.isNull[rowId] = true; |
| return; |
| } |
| |
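        // Dispatch on the Flink logical type; each branch assumes the batch was created from an
        // ORC schema that matches fieldTypes, so the vector cast is safe.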
| switch (type.getTypeRoot()) { |
| case CHAR: |
| case VARCHAR: |
| { |
| BytesColumnVector vector = (BytesColumnVector) column; |
| byte[] bytes = row.getString(columnId).toBytes(); |
| vector.setVal(rowId, bytes, 0, bytes.length); |
| break; |
| } |
| case BOOLEAN: |
| { |
| LongColumnVector vector = (LongColumnVector) column; |
| vector.vector[rowId] = row.getBoolean(columnId) ? 1 : 0; |
| break; |
| } |
| case BINARY: |
| case VARBINARY: |
| { |
| BytesColumnVector vector = (BytesColumnVector) column; |
| byte[] bytes = row.getBinary(columnId); |
| vector.setVal(rowId, bytes, 0, bytes.length); |
| break; |
| } |
| case DECIMAL: |
| { |
| DecimalType dt = (DecimalType) type; |
| DecimalColumnVector vector = (DecimalColumnVector) column; |
| vector.set( |
| rowId, |
| HiveDecimal.create( |
| row.getDecimal(columnId, dt.getPrecision(), dt.getScale()) |
| .toBigDecimal())); |
| break; |
| } |
| case TINYINT: |
| { |
| LongColumnVector vector = (LongColumnVector) column; |
| vector.vector[rowId] = row.getByte(columnId); |
| break; |
| } |
| case SMALLINT: |
| { |
| LongColumnVector vector = (LongColumnVector) column; |
| vector.vector[rowId] = row.getShort(columnId); |
| break; |
| } |
| case DATE: |
| case TIME_WITHOUT_TIME_ZONE: |
| case INTEGER: |
| { |
| LongColumnVector vector = (LongColumnVector) column; |
| vector.vector[rowId] = row.getInt(columnId); |
| break; |
| } |
| case BIGINT: |
| { |
| LongColumnVector vector = (LongColumnVector) column; |
| vector.vector[rowId] = row.getLong(columnId); |
| break; |
| } |
| case FLOAT: |
| { |
| DoubleColumnVector vector = (DoubleColumnVector) column; |
| vector.vector[rowId] = row.getFloat(columnId); |
| break; |
| } |
| case DOUBLE: |
| { |
| DoubleColumnVector vector = (DoubleColumnVector) column; |
| vector.vector[rowId] = row.getDouble(columnId); |
| break; |
| } |
| case TIMESTAMP_WITHOUT_TIME_ZONE: |
| { |
| TimestampType tt = (TimestampType) type; |
| Timestamp timestamp = |
| row.getTimestamp(columnId, tt.getPrecision()).toTimestamp(); |
| TimestampColumnVector vector = (TimestampColumnVector) column; |
| vector.set(rowId, timestamp); |
| break; |
| } |
| case TIMESTAMP_WITH_LOCAL_TIME_ZONE: |
| { |
| LocalZonedTimestampType lt = (LocalZonedTimestampType) type; |
| Timestamp timestamp = |
| row.getTimestamp(columnId, lt.getPrecision()).toTimestamp(); |
| TimestampColumnVector vector = (TimestampColumnVector) column; |
| vector.set(rowId, timestamp); |
| break; |
| } |
| default: |
| throw new UnsupportedOperationException("Unsupported type: " + type); |
| } |
| } |
| } |