/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;
import static org.apache.beam.vendor.calcite.v1_20_0.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.runners.direct.DirectOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.PipelineResult.State;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.CharType;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.DateType;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeType;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.MetricNameFilter;
import org.apache.beam.sdk.metrics.MetricQueryResults;
import org.apache.beam.sdk.metrics.MetricResult;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.metrics.MetricsFilter;
import org.apache.beam.sdk.options.ApplicationNameOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.runners.TransformHierarchy.Node;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.EnumerableRelImplementor;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.linq4j.Enumerable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.linq4j.Linq4j;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.linq4j.tree.Expression;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.linq4j.tree.Expressions;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptCost;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptPlanner;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.convert.ConverterImpl;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
import org.joda.time.Duration;
import org.joda.time.ReadableInstant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** BeamRelNode to replace an {@code Enumerable} node. */
public class BeamEnumerableConverter extends ConverterImpl implements EnumerableRel {
private static final Logger LOG = LoggerFactory.getLogger(BeamEnumerableConverter.class);
public BeamEnumerableConverter(RelOptCluster cluster, RelTraitSet traits, RelNode input) {
super(cluster, ConventionTraitDef.INSTANCE, traits, input);
}
@Override
public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
return new BeamEnumerableConverter(getCluster(), traitSet, sole(inputs));
}
@Override
public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
// This should always be a last resort.
return planner.getCostFactory().makeHugeCost();
}
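  /**
   * Generates an expression that stashes the child {@link BeamRelNode} and invokes
   * {@link #toEnumerable(BeamRelNode)} on it at run time, bridging Calcite's enumerable
   * convention to Beam pipeline execution.
   */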
@Override
public Result implement(EnumerableRelImplementor implementor, Prefer prefer) {
final BlockBuilder list = new BlockBuilder();
final RelDataType rowType = getRowType();
final PhysType physType =
PhysTypeImpl.of(implementor.getTypeFactory(), rowType, prefer.preferArray());
final Expression node = implementor.stash((BeamRelNode) getInput(), BeamRelNode.class);
list.add(Expressions.call(BeamEnumerableConverter.class, "toEnumerable", node));
return implementor.result(physType, list.toBlock());
}
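  /**
   * Runs the pipeline rooted at {@code node} and exposes its output as an {@link Enumerable}.
   * The context class loader is swapped to this class's loader for the duration of the call and
   * restored afterwards.
   */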
public static Enumerable<Object> toEnumerable(BeamRelNode node) {
final ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(BeamEnumerableConverter.class.getClassLoader());
final PipelineOptions options = createPipelineOptions(node.getPipelineOptions());
return toEnumerable(options, node);
} finally {
Thread.currentThread().setContextClassLoader(originalClassLoader);
}
}
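  /** Like {@link #toEnumerable(BeamRelNode)}, but materializes the output as a {@link List} of {@link Row}s. */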
public static List<Row> toRowList(BeamRelNode node) {
final ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(BeamEnumerableConverter.class.getClassLoader());
final PipelineOptions options = createPipelineOptions(node.getPipelineOptions());
return toRowList(options, node);
} finally {
Thread.currentThread().setContextClassLoader(originalClassLoader);
}
}
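  /**
   * Builds {@link PipelineOptions} from a name-to-value map by rendering each entry as a
   * {@code --name=value} command-line flag. For example (hypothetical input), the entry
   * {@code "runner" -> "org.apache.beam.runners.direct.DirectRunner"} becomes the flag
   * {@code --runner=org.apache.beam.runners.direct.DirectRunner}.
   */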
public static PipelineOptions createPipelineOptions(Map<String, String> map) {
final String[] args = new String[map.size()];
int i = 0;
for (Map.Entry<String, String> entry : map.entrySet()) {
args[i++] = "--" + entry.getKey() + "=" + entry.getValue();
}
PipelineOptions options = PipelineOptionsFactory.fromArgs(args).withValidation().create();
options.as(ApplicationNameOptions.class).setAppName("BeamSql");
return options;
}
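  /**
   * Runs the pipeline and collects every result row in memory. Sink relations and LIMIT queries
   * are not supported here; use {@link #toEnumerable(PipelineOptions, BeamRelNode)} for those.
   */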
static List<Row> toRowList(PipelineOptions options, BeamRelNode node) {
if (node instanceof BeamIOSinkRel) {
throw new UnsupportedOperationException("Does not support BeamIOSinkRel in toRowList.");
} else if (isLimitQuery(node)) {
throw new UnsupportedOperationException("Does not support queries with LIMIT in toRowList.");
}
return collectRows(options, node).stream().collect(Collectors.toList());
}
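  /**
   * Dispatches on the shape of the plan: sink relations return the count of written rows,
   * LIMIT-only queries are cancelled once enough rows arrive, and all other queries collect
   * their full result set.
   */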
static Enumerable<Object> toEnumerable(PipelineOptions options, BeamRelNode node) {
if (node instanceof BeamIOSinkRel) {
return count(options, node);
} else if (isLimitQuery(node)) {
return limitCollect(options, node);
}
    return Linq4j.asEnumerable(rowToAvaticaAndUnboxValues(collectRows(options, node)));
}
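  /**
   * Runs the pipeline without blocking and polls it once per second, cancelling the run as soon
   * as {@code values} holds at least {@code limitCount} rows or the pipeline reaches a terminal
   * state on its own.
   */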
private static PipelineResult limitRun(
PipelineOptions options,
BeamRelNode node,
DoFn<Row, Void> doFn,
Queue<Row> values,
int limitCount) {
options.as(DirectOptions.class).setBlockOnRun(false);
Pipeline pipeline = Pipeline.create(options);
PCollection<Row> resultCollection = BeamSqlRelUtils.toPCollection(pipeline, node);
resultCollection.apply(ParDo.of(doFn));
PipelineResult result = pipeline.run();
State state;
while (true) {
      // Check the pipeline state every second.
state = result.waitUntilFinish(Duration.standardSeconds(1));
if (state != null && state.isTerminal()) {
break;
}
try {
if (values.size() >= limitCount) {
result.cancel();
break;
}
} catch (IOException e) {
        LOG.warn("Unable to cancel the pipeline.", e);
break;
}
}
return result;
}
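  /** Runs the pipeline to completion, with a {@link Collector} appending each output row to the shared queue. */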
private static void runCollector(PipelineOptions options, BeamRelNode node) {
Pipeline pipeline = Pipeline.create(options);
PCollection<Row> resultCollection = BeamSqlRelUtils.toPCollection(pipeline, node);
resultCollection.apply(ParDo.of(new Collector()));
PipelineResult result = pipeline.run();
result.waitUntilFinish();
}
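  /**
   * Registers a fresh queue under this pipeline's options id in {@link Collector#globalValues},
   * runs the pipeline to completion, and returns the collected rows. This only works on the
   * DirectRunner, where the {@link Collector} executes in the same JVM as this method.
   */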
private static Queue<Row> collectRows(PipelineOptions options, BeamRelNode node) {
long id = options.getOptionsId();
Queue<Row> values = new ConcurrentLinkedQueue<>();
checkArgument(
options
.getRunner()
.getCanonicalName()
.equals("org.apache.beam.runners.direct.DirectRunner"),
"collectRowList is only available in direct runner.");
Collector.globalValues.put(id, values);
runCollector(options, node);
Collector.globalValues.remove(id);
return values;
}
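  /**
   * Variant of {@link #collectRows(PipelineOptions, BeamRelNode)} for LIMIT-only queries: the
   * run is cancelled once enough rows are available, and surplus rows are dropped before
   * conversion.
   */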
private static Enumerable<Object> limitCollect(PipelineOptions options, BeamRelNode node) {
long id = options.getOptionsId();
ConcurrentLinkedQueue<Row> values = new ConcurrentLinkedQueue<>();
checkArgument(
options
.getRunner()
.getCanonicalName()
.equals("org.apache.beam.runners.direct.DirectRunner"),
"SELECT without INSERT is only supported in DirectRunner in SQL Shell.");
int limitCount = getLimitCount(node);
Collector.globalValues.put(id, values);
limitRun(options, node, new Collector(), values, limitCount);
Collector.globalValues.remove(id);
    // Drop any rows collected beyond the limit.
while (values.size() > limitCount) {
values.remove();
}
return Linq4j.asEnumerable(rowToAvaticaAndUnboxValues(values));
}
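  /** {@link DoFn} that appends every element to the JVM-global queue registered for its pipeline's options id. */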
private static class Collector extends DoFn<Row, Void> {
// This will only work on the direct runner.
private static final Map<Long, Queue<Row>> globalValues = new ConcurrentHashMap<>();
@Nullable private volatile Queue<Row> values;
@StartBundle
public void startBundle(StartBundleContext context) {
long id = context.getPipelineOptions().getOptionsId();
values = globalValues.get(id);
}
@ProcessElement
public void processElement(ProcessContext context) {
values.add(context.element());
}
}
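  /** Converts collected {@link Row}s to their Avatica representation, unboxing single-column rows. */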
private static List<Object> rowToAvaticaAndUnboxValues(Queue<Row> values) {
return values.stream()
.map(
row -> {
Object[] objects = rowToAvatica(row);
if (objects.length == 1) {
                // A Row with a single column must be unboxed to a plain Object rather than an
                // Object[], because CalciteResultSet converts one-column rows to a single Object.
return objects[0];
} else {
return objects;
}
})
.collect(Collectors.toList());
}
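  /** Converts each field of a {@link Row} to its Avatica representation, preserving column order. */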
private static Object[] rowToAvatica(Row row) {
Schema schema = row.getSchema();
Object[] convertedColumns = new Object[schema.getFields().size()];
int i = 0;
for (Schema.Field field : schema.getFields()) {
convertedColumns[i] = fieldToAvatica(field.getType(), row.getValue(i));
++i;
}
return convertedColumns;
}
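  /**
   * Converts one Beam schema value to the representation Avatica expects: TIME becomes
   * millis-of-day as an int, DATE becomes epoch days as an int, DATETIME becomes epoch millis,
   * and arrays and maps are converted element-wise; primitive values pass through unchanged.
   */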
private static Object fieldToAvatica(Schema.FieldType type, Object beamValue) {
if (beamValue == null) {
return null;
}
switch (type.getTypeName()) {
case LOGICAL_TYPE:
String logicalId = type.getLogicalType().getIdentifier();
if (logicalId.equals(TimeType.IDENTIFIER)) {
return (int) ((ReadableInstant) beamValue).getMillis();
} else if (logicalId.equals(DateType.IDENTIFIER)) {
return (int) (((ReadableInstant) beamValue).getMillis() / MILLIS_PER_DAY);
} else if (logicalId.equals(CharType.IDENTIFIER)) {
return beamValue;
} else {
throw new IllegalArgumentException("Unknown DateTime type " + logicalId);
}
case DATETIME:
return ((ReadableInstant) beamValue).getMillis();
case BYTE:
case INT16:
case INT32:
case INT64:
case DECIMAL:
case FLOAT:
case DOUBLE:
case STRING:
case BOOLEAN:
case BYTES:
return beamValue;
case ARRAY:
return ((List<?>) beamValue)
.stream()
.map(elem -> fieldToAvatica(type.getCollectionElementType(), elem))
.collect(Collectors.toList());
case MAP:
return ((Map<?, ?>) beamValue)
.entrySet().stream()
.collect(
Collectors.toMap(
                    Map.Entry::getKey,
                    entry -> fieldToAvatica(type.getMapValueType(), entry.getValue())));
case ROW:
// TODO: needs to be a Struct
return beamValue;
default:
throw new IllegalStateException(
String.format("Unreachable case for Beam typename %s", type.getTypeName()));
}
}
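  /**
   * Runs a sink pipeline and returns the number of rows written, taken from the attempted value
   * of the "rows" counter. For pipelines containing an unbounded PCollection the count stays 0,
   * since the method does not wait for such a pipeline to finish.
   */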
private static Enumerable<Object> count(PipelineOptions options, BeamRelNode node) {
Pipeline pipeline = Pipeline.create(options);
BeamSqlRelUtils.toPCollection(pipeline, node).apply(ParDo.of(new RowCounter()));
PipelineResult result = pipeline.run();
long count = 0;
if (!containsUnboundedPCollection(pipeline)) {
result.waitUntilFinish();
MetricQueryResults metrics =
result
.metrics()
.queryMetrics(
MetricsFilter.builder()
.addNameFilter(MetricNameFilter.named(BeamEnumerableConverter.class, "rows"))
.build());
Iterator<MetricResult<Long>> iterator = metrics.getCounters().iterator();
if (iterator.hasNext()) {
count = iterator.next().getAttempted();
}
}
return Linq4j.singletonEnumerable(count);
}
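  /** {@link DoFn} that increments the "rows" counter once per input element. */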
private static class RowCounter extends DoFn<Row, Void> {
final Counter rows = Metrics.counter(BeamEnumerableConverter.class, "rows");
@ProcessElement
public void processElement(ProcessContext context) {
rows.inc();
}
}
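  /** Returns whether the plan is a LIMIT-only query: a limit-only {@link BeamSortRel}, or a {@link BeamCalcRel} directly over one. */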
private static boolean isLimitQuery(BeamRelNode node) {
return (node instanceof BeamSortRel && ((BeamSortRel) node).isLimitOnly())
|| (node instanceof BeamCalcRel && ((BeamCalcRel) node).isInputSortRelAndLimitOnly());
}
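  /** Extracts the LIMIT count from a plan for which {@link #isLimitQuery(BeamRelNode)} holds. */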
private static int getLimitCount(BeamRelNode node) {
if (node instanceof BeamSortRel) {
return ((BeamSortRel) node).getCount();
} else if (node instanceof BeamCalcRel) {
return ((BeamCalcRel) node).getLimitCountOfSortRel();
}
throw new RuntimeException(
"Cannot get limit count from RelNode tree with root " + node.getRelTypeName());
}
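  /** Traverses the pipeline and reports whether any {@link PCollection} in it is unbounded. */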
private static boolean containsUnboundedPCollection(Pipeline p) {
class BoundednessVisitor extends PipelineVisitor.Defaults {
IsBounded boundedness = IsBounded.BOUNDED;
@Override
public void visitValue(PValue value, Node producer) {
if (value instanceof PCollection) {
          boundedness = boundedness.and(((PCollection<?>) value).isBounded());
}
}
}
BoundednessVisitor visitor = new BoundednessVisitor();
p.traverseTopologically(visitor);
return visitor.boundedness == IsBounded.UNBOUNDED;
}
}