| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.extensions.sql.impl.rel; |
| |
| import static org.apache.beam.sdk.schemas.Schema.FieldType; |
| import static org.apache.beam.sdk.schemas.Schema.TypeName; |
| import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument; |
| import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; |
| |
| import java.lang.reflect.InvocationTargetException; |
| import java.lang.reflect.Method; |
| import java.lang.reflect.Type; |
| import java.math.BigDecimal; |
| import java.util.AbstractList; |
| import java.util.List; |
| import java.util.Map; |
| import javax.annotation.Nullable; |
| import org.apache.beam.sdk.extensions.sql.impl.planner.BeamJavaTypeFactory; |
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; |
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.CharType; |
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.DateType; |
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeType; |
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeWithLocalTzType; |
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimestampWithLocalTzType; |
| import org.apache.beam.sdk.schemas.Schema; |
| import org.apache.beam.sdk.transforms.DoFn; |
| import org.apache.beam.sdk.transforms.PTransform; |
| import org.apache.beam.sdk.transforms.ParDo; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PCollectionList; |
| import org.apache.beam.sdk.values.Row; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap; |
| import org.apache.calcite.DataContext; |
| import org.apache.calcite.adapter.enumerable.JavaRowFormat; |
| import org.apache.calcite.adapter.enumerable.PhysType; |
| import org.apache.calcite.adapter.enumerable.PhysTypeImpl; |
| import org.apache.calcite.adapter.enumerable.RexToLixTranslator; |
| import org.apache.calcite.adapter.java.JavaTypeFactory; |
| import org.apache.calcite.linq4j.QueryProvider; |
| import org.apache.calcite.linq4j.tree.BlockBuilder; |
| import org.apache.calcite.linq4j.tree.Expression; |
| import org.apache.calcite.linq4j.tree.Expressions; |
| import org.apache.calcite.linq4j.tree.GotoExpressionKind; |
| import org.apache.calcite.linq4j.tree.ParameterExpression; |
| import org.apache.calcite.linq4j.tree.Types; |
| import org.apache.calcite.plan.RelOptCluster; |
| import org.apache.calcite.plan.RelOptPredicateList; |
| import org.apache.calcite.plan.RelTraitSet; |
| import org.apache.calcite.rel.RelNode; |
| import org.apache.calcite.rel.core.Calc; |
| import org.apache.calcite.rel.metadata.RelMetadataQuery; |
| import org.apache.calcite.rex.RexBuilder; |
| import org.apache.calcite.rex.RexProgram; |
| import org.apache.calcite.rex.RexSimplify; |
| import org.apache.calcite.rex.RexUtil; |
| import org.apache.calcite.schema.SchemaPlus; |
| import org.apache.calcite.sql.validate.SqlConformance; |
| import org.apache.calcite.sql.validate.SqlConformanceEnum; |
| import org.apache.calcite.util.BuiltInMethod; |
| import org.codehaus.commons.compiler.CompileException; |
| import org.codehaus.janino.ScriptEvaluator; |
| import org.joda.time.DateTime; |
| import org.joda.time.DateTimeZone; |
| import org.joda.time.ReadableInstant; |
| |
| /** BeamRelNode to replace a {@code Project} node. */ |
| public class BeamCalcRel extends Calc implements BeamRelNode { |
| |
| private static final ParameterExpression outputSchemaParam = |
| Expressions.parameter(Schema.class, "outputSchema"); |
| private static final ParameterExpression processContextParam = |
| Expressions.parameter(DoFn.ProcessContext.class, "c"); |
| |
| public BeamCalcRel(RelOptCluster cluster, RelTraitSet traits, RelNode input, RexProgram program) { |
| super(cluster, traits, input, program); |
| } |
| |
| @Override |
| public Calc copy(RelTraitSet traitSet, RelNode input, RexProgram program) { |
| return new BeamCalcRel(getCluster(), traitSet, input, program); |
| } |
| |
| @Override |
| public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() { |
| return new Transform(); |
| } |
| |
| private class Transform extends PTransform<PCollectionList<Row>, PCollection<Row>> { |
| |
| /** |
| * expand is based on calcite's EnumerableCalc.implement(). This function generates java code |
| * executed in the processElement in CalcFn using Calcite's linq4j library. It generates a block |
| * of code using a BlockBuilder. The root of the block is an if statement with any conditions. |
| * Inside that if statement, a new record is output with a row containing transformed fields. |
| * The InputGetterImpl class generates code to read from the input record and convert to Calcite |
| * types. Calcite then generates code for any function calls or other operations. Then the |
| * castOutput method generates code to convert back to Beam Schema types. |
| */ |
| @Override |
| public PCollection<Row> expand(PCollectionList<Row> pinput) { |
| checkArgument( |
| pinput.size() == 1, |
| "Wrong number of inputs for %s: %s", |
| BeamCalcRel.class.getSimpleName(), |
| pinput); |
| PCollection<Row> upstream = pinput.get(0); |
| Schema outputSchema = CalciteUtils.toSchema(getRowType()); |
| |
| final SqlConformance conformance = SqlConformanceEnum.MYSQL_5; |
| final JavaTypeFactory typeFactory = BeamJavaTypeFactory.INSTANCE; |
| final BlockBuilder builder = new BlockBuilder(); |
| |
| final PhysType physType = |
| PhysTypeImpl.of(typeFactory, getRowType(), JavaRowFormat.ARRAY, false); |
| |
| Expression input = |
| Expressions.convert_(Expressions.call(processContextParam, "element"), Row.class); |
| |
| final RexBuilder rexBuilder = getCluster().getRexBuilder(); |
| final RelMetadataQuery mq = RelMetadataQuery.instance(); |
| final RelOptPredicateList predicates = mq.getPulledUpPredicates(getInput()); |
| final RexSimplify simplify = new RexSimplify(rexBuilder, predicates, false, RexUtil.EXECUTOR); |
| final RexProgram program = BeamCalcRel.this.program.normalize(rexBuilder, simplify); |
| |
| Expression condition = |
| RexToLixTranslator.translateCondition( |
| program, |
| typeFactory, |
| builder, |
| new InputGetterImpl(input, upstream.getSchema()), |
| null, |
| conformance); |
| |
| List<Expression> expressions = |
| RexToLixTranslator.translateProjects( |
| program, |
| typeFactory, |
| conformance, |
| builder, |
| physType, |
| DataContext.ROOT, |
| new InputGetterImpl(input, upstream.getSchema()), |
| null); |
| |
| // Expressions.call is equivalent to: output = Row.withSchema(outputSchema) |
| Expression output = Expressions.call(Row.class, "withSchema", outputSchemaParam); |
| Method addValue = Types.lookupMethod(Row.Builder.class, "addValue", Object.class); |
| |
| for (int index = 0; index < expressions.size(); index++) { |
| Expression value = expressions.get(index); |
| FieldType toType = outputSchema.getField(index).getType(); |
| |
| // Expressions.call is equivalent to: .addValue(value) |
| output = Expressions.call(output, addValue, castOutput(value, toType)); |
| } |
| |
| // Expressions.call is equivalent to: .build(); |
| output = Expressions.call(output, "build"); |
| |
| builder.add( |
| // Expressions.ifThen is equivalent to: |
| // if (condition) { |
| // c.output(output); |
| // } |
| Expressions.ifThen( |
| condition, |
| Expressions.makeGoto( |
| GotoExpressionKind.Sequence, |
| null, |
| Expressions.call( |
| processContextParam, |
| Types.lookupMethod(DoFn.ProcessContext.class, "output", Object.class), |
| output)))); |
| |
| CalcFn calcFn = new CalcFn(builder.toBlock().toString(), outputSchema); |
| |
| // validate generated code |
| calcFn.compile(); |
| |
| PCollection<Row> projectStream = upstream.apply(ParDo.of(calcFn)).setRowSchema(outputSchema); |
| |
| return projectStream; |
| } |
| } |
| |
| public int getLimitCountOfSortRel() { |
| if (input instanceof BeamSortRel) { |
| return ((BeamSortRel) input).getCount(); |
| } |
| |
| throw new RuntimeException("Could not get the limit count from a non BeamSortRel input."); |
| } |
| |
| public boolean isInputSortRelAndLimitOnly() { |
| return (input instanceof BeamSortRel) && ((BeamSortRel) input).isLimitOnly(); |
| } |
| |
| /** {@code CalcFn} is the executor for a {@link BeamCalcRel} step. */ |
| private static class CalcFn extends DoFn<Row, Row> { |
| private final String processElementBlock; |
| private final Schema outputSchema; |
| private transient @Nullable ScriptEvaluator se = null; |
| |
| public CalcFn(String processElementBlock, Schema outputSchema) { |
| this.processElementBlock = processElementBlock; |
| this.outputSchema = outputSchema; |
| } |
| |
| ScriptEvaluator compile() { |
| ScriptEvaluator se = new ScriptEvaluator(); |
| se.setParameters( |
| new String[] {outputSchemaParam.name, processContextParam.name, DataContext.ROOT.name}, |
| new Class[] { |
| (Class) outputSchemaParam.getType(), |
| (Class) processContextParam.getType(), |
| (Class) DataContext.ROOT.getType() |
| }); |
| try { |
| se.cook(processElementBlock); |
| } catch (CompileException e) { |
| throw new RuntimeException("Could not compile CalcFn: " + processElementBlock, e); |
| } |
| return se; |
| } |
| |
| @Setup |
| public void setup() { |
| this.se = compile(); |
| } |
| |
| @ProcessElement |
| public void processElement(ProcessContext c) { |
| assert se != null; |
| try { |
| se.evaluate(new Object[] {outputSchema, c, CONTEXT_INSTANCE}); |
| } catch (InvocationTargetException e) { |
| throw new RuntimeException( |
| "CalcFn failed to evaluate: " + processElementBlock, e.getCause()); |
| } |
| } |
| } |
| |
| private static final Map<TypeName, Type> rawTypeMap = |
| ImmutableMap.<TypeName, Type>builder() |
| .put(TypeName.BYTE, Byte.class) |
| .put(TypeName.INT16, Short.class) |
| .put(TypeName.INT32, Integer.class) |
| .put(TypeName.INT64, Long.class) |
| .put(TypeName.FLOAT, Float.class) |
| .put(TypeName.DOUBLE, Double.class) |
| .build(); |
| |
| private Expression castOutput(Expression value, FieldType toType) { |
| if (value.getType() == Object.class || !(value.getType() instanceof Class)) { |
| // fast copy path, just pass object through |
| return value; |
| } else if (CalciteUtils.isDateTimeType(toType) |
| && !Types.isAssignableFrom(ReadableInstant.class, (Class) value.getType())) { |
| return castOutputTime(value, toType); |
| |
| } else if (toType.getTypeName() == TypeName.DECIMAL |
| && !Types.isAssignableFrom(BigDecimal.class, (Class) value.getType())) { |
| return Expressions.new_(BigDecimal.class, value); |
| |
| } else if (((Class) value.getType()).isPrimitive() |
| || Types.isAssignableFrom(Number.class, (Class) value.getType())) { |
| Type rawType = rawTypeMap.get(toType.getTypeName()); |
| if (rawType != null) { |
| return Types.castIfNecessary(rawType, value); |
| } |
| } |
| return value; |
| } |
| |
| private Expression castOutputTime(Expression value, FieldType toType) { |
| Expression valueDateTime = value; |
| |
| // First, convert to millis |
| if (CalciteUtils.TIMESTAMP.typesEqual(toType)) { |
| if (value.getType() == java.sql.Timestamp.class) { |
| valueDateTime = Expressions.call(BuiltInMethod.TIMESTAMP_TO_LONG.method, valueDateTime); |
| } |
| } else if (CalciteUtils.TIME.typesEqual(toType)) { |
| if (value.getType() == java.sql.Time.class) { |
| valueDateTime = Expressions.call(BuiltInMethod.TIME_TO_INT.method, valueDateTime); |
| } |
| } else if (CalciteUtils.DATE.typesEqual(toType)) { |
| if (value.getType() == java.sql.Date.class) { |
| valueDateTime = Expressions.call(BuiltInMethod.DATE_TO_INT.method, valueDateTime); |
| } |
| valueDateTime = Expressions.multiply(valueDateTime, Expressions.constant(MILLIS_PER_DAY)); |
| } else { |
| throw new IllegalArgumentException("Unknown DateTime type " + toType); |
| } |
| |
| // Second, convert to joda DateTime |
| valueDateTime = |
| Expressions.new_( |
| DateTime.class, |
| valueDateTime, |
| Expressions.parameter(DateTimeZone.class, "org.joda.time.DateTimeZone.UTC")); |
| |
| // Third, make conversion conditional on non-null input. |
| if (!((Class) value.getType()).isPrimitive()) { |
| valueDateTime = |
| Expressions.condition( |
| Expressions.equal(value, Expressions.constant(null)), |
| Expressions.constant(null), |
| valueDateTime); |
| } |
| |
| return valueDateTime; |
| } |
| |
| private static class InputGetterImpl implements RexToLixTranslator.InputGetter { |
| private static final Map<TypeName, String> typeGetterMap = |
| ImmutableMap.<TypeName, String>builder() |
| .put(TypeName.BYTE, "getByte") |
| .put(TypeName.BYTES, "getBytes") |
| .put(TypeName.INT16, "getInt16") |
| .put(TypeName.INT32, "getInt32") |
| .put(TypeName.INT64, "getInt64") |
| .put(TypeName.DECIMAL, "getDecimal") |
| .put(TypeName.FLOAT, "getFloat") |
| .put(TypeName.DOUBLE, "getDouble") |
| .put(TypeName.STRING, "getString") |
| .put(TypeName.DATETIME, "getDateTime") |
| .put(TypeName.BOOLEAN, "getBoolean") |
| .put(TypeName.MAP, "getMap") |
| .put(TypeName.ARRAY, "getArray") |
| .put(TypeName.ROW, "getRow") |
| .build(); |
| |
| private static final Map<String, String> logicalTypeGetterMap = |
| ImmutableMap.<String, String>builder() |
| .put(DateType.IDENTIFIER, "getDateTime") |
| .put(TimeType.IDENTIFIER, "getDateTime") |
| .put(TimeWithLocalTzType.IDENTIFIER, "getDateTime") |
| .put(TimestampWithLocalTzType.IDENTIFIER, "getDateTime") |
| .put(CharType.IDENTIFIER, "getString") |
| .build(); |
| |
| private final Expression input; |
| private final Schema inputSchema; |
| |
| private InputGetterImpl(Expression input, Schema inputSchema) { |
| this.input = input; |
| this.inputSchema = inputSchema; |
| } |
| |
| @Override |
| public Expression field(BlockBuilder list, int index, Type storageType) { |
| if (index >= inputSchema.getFieldCount() || index < 0) { |
| throw new IllegalArgumentException("Unable to find field #" + index); |
| } |
| |
| final Expression expression = list.append("current", input); |
| if (storageType == Object.class) { |
| return Expressions.convert_( |
| Expressions.call(expression, "getValue", Expressions.constant(index)), Object.class); |
| } |
| FieldType fromType = inputSchema.getField(index).getType(); |
| String getter; |
| if (fromType.getTypeName().isLogicalType()) { |
| getter = logicalTypeGetterMap.get(fromType.getLogicalType().getIdentifier()); |
| } else { |
| getter = typeGetterMap.get(fromType.getTypeName()); |
| } |
| if (getter == null) { |
| throw new IllegalArgumentException("Unable to get " + fromType.getTypeName()); |
| } |
| Expression field = Expressions.call(expression, getter, Expressions.constant(index)); |
| if (fromType.getTypeName().isLogicalType()) { |
| field = Expressions.call(field, "getMillis"); |
| String logicalId = fromType.getLogicalType().getIdentifier(); |
| if (logicalId.equals(TimeType.IDENTIFIER)) { |
| field = Expressions.convert_(field, int.class); |
| } else if (logicalId.equals(DateType.IDENTIFIER)) { |
| field = |
| Expressions.convert_( |
| Expressions.modulo(field, Expressions.constant(MILLIS_PER_DAY)), int.class); |
| } else if (!logicalId.equals(CharType.IDENTIFIER)) { |
| throw new IllegalArgumentException( |
| "Unknown LogicalType " + fromType.getLogicalType().getIdentifier()); |
| } |
| } else if (CalciteUtils.isDateTimeType(fromType)) { |
| field = Expressions.call(field, "getMillis"); |
| } else if (fromType.getTypeName().isCompositeType() |
| || (fromType.getTypeName().isCollectionType() |
| && fromType.getCollectionElementType().getTypeName().isCompositeType())) { |
| field = Expressions.call(WrappedList.class, "of", field); |
| } |
| return field; |
| } |
| } |
| |
| private static final DataContext CONTEXT_INSTANCE = new SlimDataContext(); |
| |
| private static class SlimDataContext implements DataContext { |
| @Override |
| public SchemaPlus getRootSchema() { |
| return null; |
| } |
| |
| @Override |
| public JavaTypeFactory getTypeFactory() { |
| return null; |
| } |
| |
| @Override |
| public QueryProvider getQueryProvider() { |
| return null; |
| } |
| |
| /* DataContext.get is used to fetch "global" state inside the generated code */ |
| @Override |
| public Object get(String name) { |
| if (name.equals(DataContext.Variable.UTC_TIMESTAMP.camelName) |
| || name.equals(DataContext.Variable.CURRENT_TIMESTAMP.camelName) |
| || name.equals(DataContext.Variable.LOCAL_TIMESTAMP.camelName)) { |
| return System.currentTimeMillis(); |
| } |
| return null; |
| } |
| } |
| |
| /** WrappedList translates {@code Row} and {@code List} on access. */ |
| public static class WrappedList extends AbstractList<Object> { |
| |
| private final List<Object> list; |
| |
| private WrappedList(List<Object> list) { |
| this.list = list; |
| } |
| |
| public static List<Object> of(List list) { |
| if (list instanceof WrappedList) { |
| return list; |
| } |
| return new WrappedList(list); |
| } |
| |
| public static List<Object> of(Row row) { |
| return new WrappedList(row.getValues()); |
| } |
| |
| @Override |
| public Object get(int index) { |
| Object obj = list.get(index); |
| if (obj instanceof Row) { |
| obj = of((Row) obj); |
| } else if (obj instanceof List) { |
| obj = of((List) obj); |
| } |
| return obj; |
| } |
| |
| @Override |
| public int size() { |
| return list.size(); |
| } |
| } |
| } |