/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.orc;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.io.FileInputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.core.fs.FileInputSplit;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Row;
import org.apache.hadoop.conf.Configuration;
import org.apache.orc.OrcConf;
import org.apache.orc.OrcFile;
import org.apache.orc.Reader;
import org.apache.orc.RecordReader;
import org.apache.orc.StripeInformation;
import org.apache.orc.TypeDescription;
import org.apache.orc.storage.common.type.HiveDecimal;
import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
import org.apache.orc.storage.ql.io.sarg.PredicateLeaf;
import org.apache.orc.storage.ql.io.sarg.SearchArgument;
import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory;
import org.apache.orc.storage.serde2.io.HiveDecimalWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.apache.flink.orc.OrcBatchReader.fillRows;
/**
* InputFormat to read ORC files.
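*
* <p>A minimal usage sketch (the path and schema string below are illustrative,
* not taken from a real deployment):
* <pre>{@code
* OrcRowInputFormat format = new OrcRowInputFormat(
*     "hdfs:///path/to/orc/files",                    // hypothetical path
*     "struct<id:bigint,name:string,city:string>",    // hypothetical schema
*     new Configuration());
* format.selectFields(0, 2);
* format.addPredicate(
*     new OrcRowInputFormat.Equals("city", PredicateLeaf.Type.STRING, "Berlin"));
* }</pre>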
*/
public class OrcRowInputFormat extends FileInputFormat<Row> implements ResultTypeQueryable<Row> {
private static final Logger LOG = LoggerFactory.getLogger(OrcRowInputFormat.class);
// the number of rows read in a batch
private static final int DEFAULT_BATCH_SIZE = 1000;
// the number of rows to read in a batch
private int batchSize;
// the configuration to read with
private Configuration conf;
// the schema of the ORC files to read
private TypeDescription schema;
// the fields of the ORC schema that the returned Rows are composed of.
private int[] selectedFields;
// the type information of the Rows returned by this InputFormat.
private transient RowTypeInfo rowType;
// the ORC reader
private transient RecordReader orcRowsReader;
// the vectorized row data to be read in a batch
private transient VectorizedRowBatch rowBatch;
// the vector of rows that is read in a batch
private transient Row[] rows;
// the number of rows in the current batch
private transient int rowsInBatch;
// the index of the next row to return
private transient int nextRow;
private ArrayList<Predicate> conjunctPredicates = new ArrayList<>();
/**
* Creates an OrcRowInputFormat.
*
* @param path The path to read ORC files from.
* @param schemaString The schema of the ORC files as String.
* @param orcConfig The configuration to read the ORC files with.
*/
public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig) {
this(path, TypeDescription.fromString(schemaString), orcConfig, DEFAULT_BATCH_SIZE);
}
/**
* Creates an OrcRowInputFormat.
*
* @param path The path to read ORC files from.
* @param schemaString The schema of the ORC files as String.
* @param orcConfig The configuration to read the ORC files with.
* @param batchSize The number of Row objects to read in a batch.
*/
public OrcRowInputFormat(String path, String schemaString, Configuration orcConfig, int batchSize) {
this(path, TypeDescription.fromString(schemaString), orcConfig, batchSize);
}
/**
* Creates an OrcRowInputFormat.
*
* @param path The path to read ORC files from.
* @param orcSchema The schema of the ORC files as ORC TypeDescription.
* @param orcConfig The configuration to read the ORC files with.
* @param batchSize The number of Row objects to read in a batch.
*/
public OrcRowInputFormat(String path, TypeDescription orcSchema, Configuration orcConfig, int batchSize) {
super(new Path(path));
// configure OrcRowInputFormat
this.schema = orcSchema;
this.rowType = (RowTypeInfo) OrcBatchReader.schemaToTypeInfo(schema);
this.conf = orcConfig;
this.batchSize = batchSize;
// set default selection mask, i.e., all fields.
this.selectedFields = new int[this.schema.getChildren().size()];
for (int i = 0; i < selectedFields.length; i++) {
this.selectedFields[i] = i;
}
}
/**
* Adds a filter predicate to reduce the number of rows to be returned by the input format.
* Multiple conjunctive predicates can be added by calling this method multiple times.
*
* <p>Note: Predicates can significantly reduce the amount of data that is read.
* However, the OrcRowInputFormat does not guarantee that all returned rows satisfy the
* predicates. Moreover, predicates are only applied if the referenced field is among the
* selected fields.
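*
* <p>For example (the column name and literal are illustrative):
* <pre>{@code
* format.addPredicate(
*     new OrcRowInputFormat.Equals("city", PredicateLeaf.Type.STRING, "Berlin"));
* }</pre>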
*
* @param predicate The filter predicate.
*/
public void addPredicate(Predicate predicate) {
// validate
validatePredicate(predicate);
// add predicate
this.conjunctPredicates.add(predicate);
}
private void validatePredicate(Predicate pred) {
if (pred instanceof ColumnPredicate) {
// check column name
String colName = ((ColumnPredicate) pred).columnName;
if (!this.schema.getFieldNames().contains(colName)) {
throw new IllegalArgumentException("Predicate cannot be applied. " +
"Column '" + colName + "' does not exist in ORC schema.");
}
} else if (pred instanceof Not) {
validatePredicate(((Not) pred).child());
} else if (pred instanceof Or) {
for (Predicate p : ((Or) pred).children()) {
validatePredicate(p);
}
}
}
/**
* Selects the fields from the ORC schema that are returned by the InputFormat.
*
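* <p>For example (indices are illustrative), {@code selectFields(2, 0)} returns rows
* that contain the third and the first field of the ORC schema, in that order.
*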
* @param selectedFields The indices of the fields of the ORC schema that are returned by the InputFormat.
*/
public void selectFields(int... selectedFields) {
// set field mapping
this.selectedFields = selectedFields;
// adapt result type
this.rowType = RowTypeInfo.projectFields(this.rowType, selectedFields);
}
/**
* Computes the ORC projection mask of the fields to include from the selected fields.
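* The mask is indexed by ORC type id; selecting a field marks the field itself and
* all of its nested types as included.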
*
* @return The ORC projection mask.
*/
private boolean[] computeProjectionMask() {
// mask with all fields of the schema
boolean[] projectionMask = new boolean[schema.getMaximumId() + 1];
// for each selected field
for (int inIdx : selectedFields) {
// set all nested fields of a selected field to true
TypeDescription fieldSchema = schema.getChildren().get(inIdx);
for (int i = fieldSchema.getId(); i <= fieldSchema.getMaximumId(); i++) {
projectionMask[i] = true;
}
}
return projectionMask;
}
@Override
public void openInputFormat() throws IOException {
super.openInputFormat();
// create and initialize the row batch
this.rows = new Row[batchSize];
for (int i = 0; i < batchSize; i++) {
rows[i] = new Row(selectedFields.length);
}
}
@Override
public void open(FileInputSplit fileSplit) throws IOException {
LOG.debug("Opening ORC file {}", fileSplit.getPath());
// open ORC file and create reader
org.apache.hadoop.fs.Path hPath = new org.apache.hadoop.fs.Path(fileSplit.getPath().getPath());
Reader orcReader = OrcFile.createReader(hPath, OrcFile.readerOptions(conf));
// get offset and length for the stripes that start in the split
Tuple2<Long, Long> offsetAndLength = getOffsetAndLengthForSplit(fileSplit, getStripes(orcReader));
// create ORC row reader configuration
Reader.Options options = getOptions(orcReader)
.schema(schema)
.range(offsetAndLength.f0, offsetAndLength.f1)
.useZeroCopy(OrcConf.USE_ZEROCOPY.getBoolean(conf))
.skipCorruptRecords(OrcConf.SKIP_CORRUPT_DATA.getBoolean(conf))
.tolerateMissingSchema(OrcConf.TOLERATE_MISSING_SCHEMA.getBoolean(conf));
// configure filters
if (!conjunctPredicates.isEmpty()) {
SearchArgument.Builder b = SearchArgumentFactory.newBuilder();
b = b.startAnd();
for (Predicate predicate : conjunctPredicates) {
predicate.add(b);
}
b = b.end();
options.searchArgument(b.build(), new String[]{});
}
// configure selected fields
options.include(computeProjectionMask());
// create ORC row reader
this.orcRowsReader = orcReader.rows(options);
// ensure that the type ids of the schema are assigned (triggered by calling getId())
this.schema.getId();
// create row batch
this.rowBatch = schema.createRowBatch(batchSize);
rowsInBatch = 0;
nextRow = 0;
}
@VisibleForTesting
Reader.Options getOptions(Reader orcReader) {
return orcReader.options();
}
@VisibleForTesting
List<StripeInformation> getStripes(Reader orcReader) {
return orcReader.getStripes();
}
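/**
* Computes the byte range to read for a split: it spans all stripes whose first byte
* lies within the split, so that each stripe is read by exactly one split. If no
* stripe starts in the split, a zero-length range is returned.
*/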
private Tuple2<Long, Long> getOffsetAndLengthForSplit(FileInputSplit split, List<StripeInformation> stripes) {
long splitStart = split.getStart();
long splitEnd = splitStart + split.getLength();
long readStart = Long.MAX_VALUE;
long readEnd = Long.MIN_VALUE;
for (StripeInformation s : stripes) {
if (splitStart <= s.getOffset() && s.getOffset() < splitEnd) {
// stripe starts in split, so it is included
readStart = Math.min(readStart, s.getOffset());
readEnd = Math.max(readEnd, s.getOffset() + s.getLength());
}
}
if (readStart < Long.MAX_VALUE) {
// at least one stripe is included in the split
return Tuple2.of(readStart, readEnd - readStart);
} else {
return Tuple2.of(0L, 0L);
}
}
@Override
public void close() throws IOException {
if (orcRowsReader != null) {
this.orcRowsReader.close();
}
this.orcRowsReader = null;
}
@Override
public void closeInputFormat() throws IOException {
this.rows = null;
this.schema = null;
this.rowBatch = null;
}
@Override
public boolean reachedEnd() throws IOException {
return !ensureBatch();
}
/**
* Checks if there is at least one row left in the batch to return.
* If no more rows are available, it reads another batch of rows.
*
* @return True if there is one more row to return, false otherwise.
* @throws IOException Thrown if an exception happens while reading a batch.
*/
private boolean ensureBatch() throws IOException {
if (nextRow >= rowsInBatch) {
// No more rows available in the Rows array.
nextRow = 0;
// Try to read the next batch of rows from the ORC file.
boolean moreRows = orcRowsReader.nextBatch(rowBatch);
if (moreRows) {
// Load the data into the Rows array.
rowsInBatch = fillRows(rows, schema, rowBatch, selectedFields);
}
return moreRows;
}
// there is at least one Row left in the Rows array.
return true;
}
@Override
public Row nextRecord(Row reuse) throws IOException {
// return the next row
return rows[this.nextRow++];
}
@Override
public TypeInformation<Row> getProducedType() {
return rowType;
}
// --------------------------------------------------------------------------------------------
// Custom serialization methods
// --------------------------------------------------------------------------------------------
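// The Hadoop Configuration is not Java-serializable, so it is written and restored via
// its own Writable methods; the schema, the selected fields, and the predicates are
// encoded alongside it.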
private void writeObject(ObjectOutputStream out) throws IOException {
out.writeInt(batchSize);
this.conf.write(out);
out.writeUTF(schema.toString());
out.writeInt(selectedFields.length);
for (int f : selectedFields) {
out.writeInt(f);
}
out.writeInt(conjunctPredicates.size());
for (Predicate p : conjunctPredicates) {
out.writeObject(p);
}
}
@SuppressWarnings("unchecked")
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
batchSize = in.readInt();
org.apache.hadoop.conf.Configuration configuration = new org.apache.hadoop.conf.Configuration();
configuration.readFields(in);
if (this.conf == null) {
this.conf = configuration;
}
this.schema = TypeDescription.fromString(in.readUTF());
this.selectedFields = new int[in.readInt()];
for (int i = 0; i < selectedFields.length; i++) {
this.selectedFields[i] = in.readInt();
}
this.conjunctPredicates = new ArrayList<>();
int numPreds = in.readInt();
for (int i = 0; i < numPreds; i++) {
conjunctPredicates.add((Predicate) in.readObject());
}
}
@Override
public boolean supportsMultiPaths() {
return true;
}
// --------------------------------------------------------------------------------------------
// Getter methods for tests
// --------------------------------------------------------------------------------------------
@VisibleForTesting
Configuration getConfiguration() {
return conf;
}
@VisibleForTesting
int getBatchSize() {
return batchSize;
}
@VisibleForTesting
String getSchema() {
return schema.toString();
}
// --------------------------------------------------------------------------------------------
// Classes to define predicates
// --------------------------------------------------------------------------------------------
/**
* A filter predicate that can be evaluated by the OrcRowInputFormat.
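*
* <p>Predicates can be composed; for example (column names and literals are illustrative):
* {@code new Or(new LessThan("id", PredicateLeaf.Type.LONG, 10L),
* new Not(new IsNull("name", PredicateLeaf.Type.STRING)))}.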
*/
public abstract static class Predicate implements Serializable {
protected abstract SearchArgument.Builder add(SearchArgument.Builder builder);
}
abstract static class ColumnPredicate extends Predicate {
final String columnName;
final PredicateLeaf.Type literalType;
ColumnPredicate(String columnName, PredicateLeaf.Type literalType) {
this.columnName = columnName;
this.literalType = literalType;
}
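/**
* Casts the literal to the object representation that the ORC SearchArgument builder
* expects for this predicate's literal type, e.g., integer literals to Long and
* BigDecimal literals to HiveDecimalWritable for DECIMAL columns.
*/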
Object castLiteral(Serializable literal) {
switch (literalType) {
case LONG:
if (literal instanceof Byte) {
return Long.valueOf((Byte) literal);
} else if (literal instanceof Short) {
return Long.valueOf((Short) literal);
} else if (literal instanceof Integer) {
return Long.valueOf((Integer) literal);
} else if (literal instanceof Long) {
return literal;
} else {
throw new IllegalArgumentException("A predicate on a LONG column requires an integer " +
"literal, i.e., Byte, Short, Integer, or Long.");
}
case FLOAT:
if (literal instanceof Float) {
return Double.valueOf((Float) literal);
} else if (literal instanceof Double) {
return literal;
} else if (literal instanceof BigDecimal) {
return ((BigDecimal) literal).doubleValue();
} else {
throw new IllegalArgumentException("A predicate on a FLOAT column requires a floating " +
"point literal, i.e., Float or Double.");
}
case STRING:
if (literal instanceof String) {
return literal;
} else {
throw new IllegalArgumentException("A predicate on a STRING column requires a String literal.");
}
case BOOLEAN:
if (literal instanceof Boolean) {
return literal;
} else {
throw new IllegalArgumentException("A predicate on a BOOLEAN column requires a Boolean literal.");
}
case DATE:
if (literal instanceof Date) {
return literal;
} else {
throw new IllegalArgumentException("A predicate on a DATE column requires a java.sql.Date literal.");
}
case TIMESTAMP:
if (literal instanceof Timestamp) {
return literal;
} else {
throw new IllegalArgumentException("A predicate on a TIMESTAMP column requires a java.sql.Timestamp literal.");
}
case DECIMAL:
if (literal instanceof BigDecimal) {
return new HiveDecimalWritable(HiveDecimal.create((BigDecimal) literal));
} else {
throw new IllegalArgumentException("A predicate on a DECIMAL column requires a BigDecimal literal.");
}
default:
throw new IllegalArgumentException("Unknown literal type " + literalType);
}
}
}
abstract static class BinaryPredicate extends ColumnPredicate {
final Serializable literal;
BinaryPredicate(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
super(columnName, literalType);
this.literal = literal;
}
}
/**
* An EQUALS predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class Equals extends BinaryPredicate {
/**
* Creates an EQUALS predicate.
*
* @param columnName The column to check.
* @param literalType The type of the literal.
* @param literal The literal value to check the column against.
*/
public Equals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
super(columnName, literalType, literal);
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
return builder.equals(columnName, literalType, castLiteral(literal));
}
@Override
public String toString() {
return columnName + " = " + literal;
}
}
/**
* A null-safe EQUALS predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class NullSafeEquals extends BinaryPredicate {
/**
* Creates a null-safe EQUALS predicate.
*
* @param columnName The column to check.
* @param literalType The type of the literal.
* @param literal The literal value to check the column against.
*/
public NullSafeEquals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
super(columnName, literalType, literal);
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
return builder.nullSafeEquals(columnName, literalType, castLiteral(literal));
}
@Override
public String toString() {
return columnName + " = " + literal;
}
}
/**
* A LESS_THAN predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class LessThan extends BinaryPredicate {
/**
* Creates a LESS_THAN predicate.
*
* @param columnName The column to check.
* @param literalType The type of the literal.
* @param literal The literal value to check the column against.
*/
public LessThan(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
super(columnName, literalType, literal);
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
return builder.lessThan(columnName, literalType, castLiteral(literal));
}
@Override
public String toString() {
return columnName + " < " + literal;
}
}
/**
* A LESS_THAN_EQUALS predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class LessThanEquals extends BinaryPredicate {
/**
* Creates a LESS_THAN_EQUALS predicate.
*
* @param columnName The column to check.
* @param literalType The type of the literal.
* @param literal The literal value to check the column against.
*/
public LessThanEquals(String columnName, PredicateLeaf.Type literalType, Serializable literal) {
super(columnName, literalType, literal);
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
return builder.lessThanEquals(columnName, literalType, castLiteral(literal));
}
@Override
public String toString() {
return columnName + " <= " + literal;
}
}
/**
* An IS_NULL predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class IsNull extends ColumnPredicate {
/**
* Creates an IS_NULL predicate.
*
* @param columnName The column to check for null.
* @param literalType The type of the column to check for null.
*/
public IsNull(String columnName, PredicateLeaf.Type literalType) {
super(columnName, literalType);
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
return builder.isNull(columnName, literalType);
}
@Override
public String toString() {
return columnName + " IS NULL";
}
}
/**
* A BETWEEN predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class Between extends ColumnPredicate {
private Serializable lowerBound;
private Serializable upperBound;
/**
* Creates a BETWEEN predicate.
*
* @param columnName The column to check.
* @param literalType The type of the literals.
* @param lowerBound The literal value of the (inclusive) lower bound to check the column against.
* @param upperBound The literal value of the (inclusive) upper bound to check the column against.
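*
* <p>For example (column name and bounds are illustrative),
* {@code new Between("age", PredicateLeaf.Type.LONG, 18L, 65L)} selects rows
* where {@code 18 <= age <= 65}.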
*/
public Between(String columnName, PredicateLeaf.Type literalType, Serializable lowerBound, Serializable upperBound) {
super(columnName, literalType);
this.lowerBound = lowerBound;
this.upperBound = upperBound;
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
return builder.between(columnName, literalType, castLiteral(lowerBound), castLiteral(upperBound));
}
@Override
public String toString() {
return lowerBound + " <= " + columnName + " <= " + upperBound;
}
}
/**
* An IN predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class In extends ColumnPredicate {
private Serializable[] literals;
/**
* Creates an IN predicate.
*
* @param columnName The column to check.
* @param literalType The type of the literals.
* @param literals The literal values to check the column against.
*/
public In(String columnName, PredicateLeaf.Type literalType, Serializable... literals) {
super(columnName, literalType);
this.literals = literals;
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
Object[] castedLiterals = new Object[literals.length];
for (int i = 0; i < literals.length; i++) {
castedLiterals[i] = castLiteral(literals[i]);
}
return builder.in(columnName, literalType, (Object[]) castedLiterals);
}
@Override
public String toString() {
return columnName + " IN " + Arrays.toString(literals);
}
}
/**
* A NOT predicate to negate a predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class Not extends Predicate {
private final Predicate pred;
/**
* Creates a NOT predicate.
*
* @param predicate The predicate to negate.
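*
* <p>For example (column name is illustrative),
* {@code new Not(new IsNull("name", PredicateLeaf.Type.STRING))} expresses an
* IS NOT NULL filter.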
*/
public Not(Predicate predicate) {
this.pred = predicate;
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
return pred.add(builder.startNot()).end();
}
protected Predicate child() {
return pred;
}
@Override
public String toString() {
return "NOT(" + pred.toString() + ")";
}
}
/**
* An OR predicate that can be evaluated by the OrcRowInputFormat.
*/
public static class Or extends Predicate {
private final Predicate[] preds;
/**
* Creates an OR predicate.
*
* @param predicates The disjunctive predicates.
*/
public Or(Predicate... predicates) {
this.preds = predicates;
}
@Override
protected SearchArgument.Builder add(SearchArgument.Builder builder) {
SearchArgument.Builder withOr = builder.startOr();
for (Predicate p : preds) {
withOr = p.add(withOr);
}
return withOr.end();
}
protected Iterable<Predicate> children() {
return Arrays.asList(preds);
}
@Override
public String toString() {
return "OR(" + Arrays.toString(preds) + ")";
}
}
}