/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.orc.nohive;

import org.apache.flink.api.common.serialization.BulkWriter;
import org.apache.flink.core.fs.FSDataOutputStream;
import org.apache.flink.orc.nohive.writer.NoHivePhysicalWriterImpl;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.TimestampType;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.orc.OrcFile;
import org.apache.orc.TypeDescription;
import org.apache.orc.impl.WriterImpl;
import org.apache.orc.storage.common.type.HiveDecimal;
import org.apache.orc.storage.ql.exec.vector.BytesColumnVector;
import org.apache.orc.storage.ql.exec.vector.ColumnVector;
import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector;
import org.apache.orc.storage.ql.exec.vector.DoubleColumnVector;
import org.apache.orc.storage.ql.exec.vector.LongColumnVector;
import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector;
import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.sql.Timestamp;
import java.util.Properties;

/**
 * A {@link BulkWriter.Factory} that writes {@link RowData} records as ORC files using the
 * no-hive variant of ORC (storage-api classes relocated under {@code org.apache.orc.storage},
 * so no Hive dependency is required).
 *
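 * <p>A minimal usage sketch, assuming Flink's {@code FileSink}; {@code basePath} and the
 * concrete schema and field types below are illustrative, not prescribed by this class:
 *
 * <pre>{@code
 * LogicalType[] fieldTypes = {new IntType(), new VarCharType(VarCharType.MAX_LENGTH)};
 * OrcNoHiveBulkWriterFactory factory =
 *         new OrcNoHiveBulkWriterFactory(
 *                 new Configuration(), "struct<id:int,name:string>", fieldTypes);
 * FileSink<RowData> sink = FileSink.forBulkFormat(basePath, factory).build();
 * }</pre>
 */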
public class OrcNoHiveBulkWriterFactory implements BulkWriter.Factory<RowData> {

    // Mutable (non-final) so that readObject() below can repopulate the fields.
    private Configuration conf;
    private String schema;
    private LogicalType[] fieldTypes;

    public OrcNoHiveBulkWriterFactory(Configuration conf, String schema, LogicalType[] fieldTypes) {
        this.conf = conf;
        this.schema = schema;
        this.fieldTypes = fieldTypes;
    }

@Override
public BulkWriter<RowData> create(FSDataOutputStream out) throws IOException {
OrcFile.WriterOptions opts = OrcFile.writerOptions(new Properties(), conf);
TypeDescription description = TypeDescription.fromString(schema);
opts.setSchema(description);
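        // Redirect all physical output to Flink's FSDataOutputStream rather than a
        // Hadoop FileSystem path.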
opts.physicalWriter(new NoHivePhysicalWriterImpl(out, opts));
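        // The null FileSystem and the dummy Path are never used because a PhysicalWriter
        // is supplied above; they only satisfy the WriterImpl constructor signature.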
WriterImpl writer = new WriterImpl(null, new Path("."), opts);
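        // A reusable vectorized batch buffers incoming rows; it is handed to the ORC
        // writer whenever it fills up (see addElement) and on flush().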
VectorizedRowBatch rowBatch = description.createRowBatch();
return new BulkWriter<RowData>() {
@Override
public void addElement(RowData row) throws IOException {
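                // Claim the next slot in the batch and copy each field into its column.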
int rowId = rowBatch.size++;
for (int i = 0; i < row.getArity(); ++i) {
setColumn(rowId, rowBatch.cols[i], fieldTypes[i], row, i);
}
if (rowBatch.size == rowBatch.getMaxSize()) {
writer.addRowBatch(rowBatch);
rowBatch.reset();
}
}

            @Override
public void flush() throws IOException {
if (rowBatch.size != 0) {
writer.addRowBatch(rowBatch);
rowBatch.reset();
}
}

            @Override
public void finish() throws IOException {
flush();
writer.close();
}
};
}

    // Custom Java serialization: Hadoop's Configuration is not java.io.Serializable, so
    // it is written and restored via its Writable-style write()/readFields() methods.
private void writeObject(ObjectOutputStream out) throws IOException {
conf.write(out);
out.writeObject(schema);
out.writeObject(fieldTypes);
}

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
conf = new Configuration(false);
conf.readFields(in);
schema = (String) in.readObject();
fieldTypes = (LogicalType[]) in.readObject();
}
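
    // Copies one field of the given row into the matching ORC column vector, dispatching
    // on the field's Flink logical type. Nulls are recorded in the vector's null mask.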
private static void setColumn(
int rowId, ColumnVector column, LogicalType type, RowData row, int columnId) {
if (row.isNullAt(columnId)) {
column.noNulls = false;
column.isNull[rowId] = true;
return;
}
switch (type.getTypeRoot()) {
case CHAR:
case VARCHAR:
{
BytesColumnVector vector = (BytesColumnVector) column;
byte[] bytes = row.getString(columnId).toBytes();
vector.setVal(rowId, bytes, 0, bytes.length);
break;
}
case BOOLEAN:
{
LongColumnVector vector = (LongColumnVector) column;
vector.vector[rowId] = row.getBoolean(columnId) ? 1 : 0;
break;
}
case BINARY:
case VARBINARY:
{
BytesColumnVector vector = (BytesColumnVector) column;
byte[] bytes = row.getBinary(columnId);
vector.setVal(rowId, bytes, 0, bytes.length);
break;
}
case DECIMAL:
{
DecimalType dt = (DecimalType) type;
DecimalColumnVector vector = (DecimalColumnVector) column;
vector.set(
rowId,
HiveDecimal.create(
row.getDecimal(columnId, dt.getPrecision(), dt.getScale())
.toBigDecimal()));
break;
}
case TINYINT:
{
LongColumnVector vector = (LongColumnVector) column;
vector.vector[rowId] = row.getByte(columnId);
break;
}
case SMALLINT:
{
LongColumnVector vector = (LongColumnVector) column;
vector.vector[rowId] = row.getShort(columnId);
break;
}
case DATE:
case TIME_WITHOUT_TIME_ZONE:
case INTEGER:
{
LongColumnVector vector = (LongColumnVector) column;
vector.vector[rowId] = row.getInt(columnId);
break;
}
case BIGINT:
{
LongColumnVector vector = (LongColumnVector) column;
vector.vector[rowId] = row.getLong(columnId);
break;
}
case FLOAT:
{
DoubleColumnVector vector = (DoubleColumnVector) column;
vector.vector[rowId] = row.getFloat(columnId);
break;
}
case DOUBLE:
{
DoubleColumnVector vector = (DoubleColumnVector) column;
vector.vector[rowId] = row.getDouble(columnId);
break;
}
case TIMESTAMP_WITHOUT_TIME_ZONE:
{
TimestampType tt = (TimestampType) type;
Timestamp timestamp =
row.getTimestamp(columnId, tt.getPrecision()).toTimestamp();
TimestampColumnVector vector = (TimestampColumnVector) column;
vector.set(rowId, timestamp);
break;
}
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
{
LocalZonedTimestampType lt = (LocalZonedTimestampType) type;
Timestamp timestamp =
row.getTimestamp(columnId, lt.getPrecision()).toTimestamp();
TimestampColumnVector vector = (TimestampColumnVector) column;
vector.set(rowId, timestamp);
break;
}
default:
throw new UnsupportedOperationException("Unsupported type: " + type);
}
}
}