/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.meta.provider.test;
import static org.apache.beam.vendor.calcite.v1_20_0.com.google.common.base.Preconditions.checkArgument;
import com.google.auto.service.AutoService;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
import org.apache.beam.sdk.extensions.sql.meta.BaseBeamTable;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTableFilter;
import org.apache.beam.sdk.extensions.sql.meta.DefaultTableFilter;
import org.apache.beam.sdk.extensions.sql.meta.Table;
import org.apache.beam.sdk.extensions.sql.meta.provider.InMemoryMetaTableProvider;
import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.FieldTypeDescriptors;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.transforms.Filter;
import org.apache.beam.sdk.schemas.transforms.Select;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
/**
* In-memory table provider for use in tests.
*
* <p>Keeps global state and tracks class instances. Works only with the DirectRunner.
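*
* <p>Typical test usage (an illustrative sketch; {@code table} stands for any {@link Table} of
* type {@code "test"}, and {@code row1}, {@code row2} for rows matching its schema):
*
* <pre>{@code
* TestTableProvider provider = new TestTableProvider();
* provider.createTable(table);
* provider.addRows(table.getName(), row1, row2);
* List<Row> stored = provider.tableRows(table.getName());
* }</pre>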
*/
@AutoService(TableProvider.class)
public class TestTableProvider extends InMemoryMetaTableProvider {
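// Tables (metadata and rows) for every provider instance, keyed by instance id, then by table name.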
static final Map<Long, Map<String, TableWithRows>> GLOBAL_TABLES = new ConcurrentHashMap<>();
public static final String PUSH_DOWN_OPTION = "push_down";
private static final AtomicLong INSTANCES = new AtomicLong(0);
private final long instanceId = INSTANCES.getAndIncrement();
public TestTableProvider() {
GLOBAL_TABLES.put(instanceId, new ConcurrentHashMap<>());
}
@Override
public String getTableType() {
return "test";
}
public Map<String, TableWithRows> tables() {
return GLOBAL_TABLES.get(instanceId);
}
@Override
public void createTable(Table table) {
tables().put(table.getName(), new TableWithRows(instanceId, table));
}
@Override
public void dropTable(String tableName) {
tables().remove(tableName);
}
@Override
public Map<String, Table> getTables() {
return tables().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().table));
}
@Override
public synchronized BeamSqlTable buildBeamSqlTable(Table table) {
return new InMemoryTable(tables().get(table.getName()));
}
public void addRows(String tableName, Row... rows) {
checkArgument(tables().containsKey(tableName), "Table not found: " + tableName);
tables().get(tableName).rows.addAll(Arrays.asList(rows));
}
public List<Row> tableRows(String tableName) {
return tables().get(tableName).rows;
}
/** A {@link Table} together with the rows it currently holds. */
public static class TableWithRows implements Serializable {
private Table table;
private List<Row> rows;
private long tableProviderInstanceId;
public TableWithRows(long tableProviderInstanceId, Table table) {
this.tableProviderInstanceId = tableProviderInstanceId;
this.table = table;
this.rows = new CopyOnWriteArrayList<>();
}
public List<Row> getRows() {
return rows;
}
}
private static class InMemoryTable extends BaseBeamTable {
private TableWithRows tableWithRows;
private PushDownOptions options;
@Override
public PCollection.IsBounded isBounded() {
return PCollection.IsBounded.BOUNDED;
}
public InMemoryTable(TableWithRows tableWithRows) {
this.tableWithRows = tableWithRows;
// The reason for introducing a property here is to simplify writing unit tests, testing
// project and predicate push-down behavior when run separately and together.
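// Accepted (case-insensitive) values of the property are the PushDownOptions names: none,
// project, filter and both. When the property is absent, push-down is disabled.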
if (tableWithRows.table.getProperties().containsKey(PUSH_DOWN_OPTION)) {
options =
PushDownOptions.valueOf(
tableWithRows.table.getProperties().getString(PUSH_DOWN_OPTION).toUpperCase());
} else {
options = PushDownOptions.NONE;
}
}
public Coder<Row> rowCoder() {
return SchemaCoder.of(tableWithRows.table.getSchema());
}
@Override
public BeamTableStatistics getTableStatistics(PipelineOptions options) {
return BeamTableStatistics.createBoundedTableStatistics(
(double) tableWithRows.getRows().size());
}
@Override
public PCollection<Row> buildIOReader(PBegin begin) {
TableWithRows tableWithRows =
GLOBAL_TABLES
.get(this.tableWithRows.tableProviderInstanceId)
.get(this.tableWithRows.table.getName());
return begin.apply(Create.of(tableWithRows.rows).withCoder(rowCoder()));
}
@Override
public PCollection<Row> buildIOReader(
PBegin begin, BeamSqlTableFilter filters, List<String> fieldNames) {
if (!(filters instanceof DefaultTableFilter)
&& (options == PushDownOptions.NONE || options == PushDownOptions.PROJECT)) {
throw new RuntimeException(
"Filter push-down is not supported, yet non-default filter was passed.");
}
if ((!fieldNames.isEmpty() && fieldNames.size() < getSchema().getFieldCount())
&& (options == PushDownOptions.NONE || options == PushDownOptions.FILTER)) {
throw new RuntimeException(
"Project push-down is not supported, yet a list of fieldNames was passed.");
}
PCollection<Row> withAllFields = buildIOReader(begin);
if (options == PushDownOptions.NONE) { // needed for testing purposes
return withAllFields;
}
PCollection<Row> result = withAllFields;
// When filter push-down is supported.
if (options == PushDownOptions.FILTER || options == PushDownOptions.BOTH) {
if (filters instanceof TestTableFilter) {
// Create a filter for each supported node.
for (RexNode node : ((TestTableFilter) filters).getSupported()) {
result = result.apply("IOPushDownFilter_" + node.toString(), filterFromNode(node));
}
} else {
throw new RuntimeException(
"Was expecting a filter of type TestTableFilter, but received: "
+ filters.getClass().getSimpleName());
}
}
// When project push-down is supported.
if ((options == PushDownOptions.PROJECT || options == PushDownOptions.BOTH)
&& !fieldNames.isEmpty()) {
result =
result.apply(
"IOPushDownProject",
Select.fieldAccess(
FieldAccessDescriptor.withFieldNames(fieldNames)
.withOrderByFieldInsertionOrder()));
}
return result;
}
@Override
public POutput buildIOWriter(PCollection<Row> input) {
input.apply(ParDo.of(new CollectorFn(tableWithRows)));
return PDone.in(input.getPipeline());
}
@Override
public BeamSqlTableFilter constructFilter(List<RexNode> filter) {
if (options == PushDownOptions.FILTER || options == PushDownOptions.BOTH) {
return new TestTableFilter(filter);
}
return super.constructFilter(filter);
}
@Override
public boolean supportsProjects() {
return options == PushDownOptions.BOTH || options == PushDownOptions.PROJECT;
}
@Override
public Schema getSchema() {
return tableWithRows.table.getSchema();
}
/**
* A helper method to create a {@code Filter} from a {@code RexNode}.
*
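* <p>For example (illustrative), a comparison such as {@code <($0, 100)} becomes a filter that
* keeps rows whose field 0 is less than 100, and a bare boolean input ref such as {@code $2} is
* treated as {@code =($2, true)}.
*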
* @param node {@code RexNode} to create a filter from.
* @return {@code Filter} PTransform.
*/
private PTransform<PCollection<Row>, PCollection<Row>> filterFromNode(RexNode node) {
List<RexNode> operands = new ArrayList<>();
List<Integer> fieldIds = new ArrayList<>();
List<RexLiteral> literals = new ArrayList<>();
List<RexInputRef> inputRefs = new ArrayList<>();
if (node instanceof RexCall) {
operands.addAll(((RexCall) node).getOperands());
} else if (node instanceof RexInputRef) {
operands.add(node);
operands.add(RexLiteral.fromJdbcString(node.getType(), SqlTypeName.BOOLEAN, "true"));
} else {
throw new RuntimeException(
"Was expecting a RexCall or a boolean RexInputRef, but received: "
+ node.getClass().getSimpleName());
}
for (RexNode operand : operands) {
if (operand instanceof RexInputRef) {
RexInputRef inputRef = (RexInputRef) operand;
fieldIds.add(inputRef.getIndex());
inputRefs.add(inputRef);
} else if (operand instanceof RexLiteral) {
RexLiteral literal = (RexLiteral) operand;
literals.add(literal);
} else {
throw new RuntimeException(
"Encountered an unexpected operand: " + operand.getClass().getSimpleName());
}
}
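// At this point fieldIds holds the indices of the referenced columns, inputRefs the column
// operands, and literals the constant operands of the comparison.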
SerializableFunction<Integer, Boolean> comparison;
// TODO: add support for expressions like:
// =(CAST($3):INTEGER NOT NULL, 200)
switch (node.getKind()) {
case LESS_THAN:
comparison = i -> i < 0;
break;
case GREATER_THAN:
comparison = i -> i > 0;
break;
case LESS_THAN_OR_EQUAL:
comparison = i -> i <= 0;
break;
case GREATER_THAN_OR_EQUAL:
comparison = i -> i >= 0;
break;
case EQUALS:
case INPUT_REF:
comparison = i -> i == 0;
break;
case NOT_EQUALS:
comparison = i -> i != 0;
break;
default:
throw new RuntimeException("Unsupported node kind: " + node.getKind().toString());
}
return Filter.<Row>create()
.whereFieldIds(
fieldIds, createFilter(operands, fieldIds, inputRefs, literals, comparison));
}
/**
* A helper method to create a serializable function comparing row fields.
*
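* <p>Illustrative sketch: for {@code <($0, 100)} the literal-vs-column branch below returns a
* function that roughly evaluates {@code fieldValue.compareTo(100) < 0} for each row.
*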
* @param operands A list of operands used in a comparison.
* @param fieldIds A list of operand ids.
* @param inputRefs A list of operands, which are an instanceof {@code RexInputRef}.
* @param literals A list of operands, which are an instanceof {@code RexLiteral}.
* @param comparison A comparison to perform between operands.
* @return A filter comparing row fields to literals/other fields.
*/
private SerializableFunction<Row, Boolean> createFilter(
List<RexNode> operands,
List<Integer> fieldIds,
List<RexInputRef> inputRefs,
List<RexLiteral> literals,
SerializableFunction<Integer, Boolean> comparison) {
// Filter push-down only supports comparisons between 2 operands (for now).
assert operands.size() == 2;
// Comparing two columns (2 input refs).
assert inputRefs.size() <= 2;
// A comparison of 2 literals should never appear here; it gets optimized away.
assert literals.size() < 2;
if (inputRefs.size() == 2) { // Comparing 2 columns.
final int op0 = fieldIds.indexOf(inputRefs.get(0).getIndex());
final int op1 = fieldIds.indexOf(inputRefs.get(1).getIndex());
return row -> comparison.apply(row.<Comparable>getValue(op0).compareTo(row.getValue(op1)));
}
// Comparing a column to a literal.
int fieldSchemaIndex = inputRefs.get(0).getIndex();
FieldType beamFieldType = getSchema().getField(fieldSchemaIndex).getType();
final int op0 = fieldIds.indexOf(fieldSchemaIndex);
// Convert the literal to the Java type the schema declares for this field.
final Comparable op1 =
literals
.get(0)
.<Comparable>getValueAs(
FieldTypeDescriptors.javaTypeForFieldType(beamFieldType).getRawType());
if (operands.get(0) instanceof RexLiteral) { // First operand is a literal
return row -> comparison.apply(op1.compareTo(row.getValue(op0)));
} else if (operands.get(0) instanceof RexInputRef) { // First operand is a column value
return row -> comparison.apply(row.<Comparable>getValue(op0).compareTo(op1));
} else {
throw new RuntimeException(
"Was expecting a RexLiteral and a RexInputRef, but received: "
+ operands.stream()
.map(o -> o.getClass().getSimpleName())
.collect(Collectors.joining(", ")));
}
}
}
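/** A {@link DoFn} that appends every incoming row to the backing in-memory table and re-emits it. */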
private static final class CollectorFn extends DoFn<Row, Row> {
private TableWithRows tableWithRows;
CollectorFn(TableWithRows tableWithRows) {
this.tableWithRows = tableWithRows;
}
@ProcessElement
public void processElement(ProcessContext context) {
long instanceId = tableWithRows.tableProviderInstanceId;
String tableName = tableWithRows.table.getName();
GLOBAL_TABLES.get(instanceId).get(tableName).rows.add(context.element());
context.output(context.element());
}
}
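/**
* Push-down behavior of an {@link InMemoryTable}, selected via the {@code push_down} table
* property: {@code NONE} disables push-down, {@code PROJECT} enables only project push-down,
* {@code FILTER} enables only filter push-down, and {@code BOTH} enables both.
*/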
public enum PushDownOptions {
NONE,
PROJECT,
FILTER,
BOTH
}
}