| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to you under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.calcite.adapter.pig; |
| |
| import org.apache.calcite.plan.RelOptCluster; |
| import org.apache.calcite.plan.RelOptTable; |
| import org.apache.calcite.plan.RelTraitSet; |
| import org.apache.calcite.rel.RelNode; |
| import org.apache.calcite.rel.core.Aggregate; |
| import org.apache.calcite.rel.core.AggregateCall; |
| import org.apache.calcite.rel.type.RelDataTypeField; |
| import org.apache.calcite.util.ImmutableBitSet; |
| |
| import org.apache.pig.scripting.Pig; |
| |
| import com.google.common.collect.ImmutableList; |
| |
| import java.util.ArrayList; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Set; |
| |
| /** Implementation of {@link org.apache.calcite.rel.core.Aggregate} in |
| * {@link PigRel#CONVENTION Pig calling convention}. */ |
| public class PigAggregate extends Aggregate implements PigRel { |
| |
| public static final String DISTINCT_FIELD_SUFFIX = "_DISTINCT"; |
| |
| /** Creates a PigAggregate. */ |
| public PigAggregate(RelOptCluster cluster, RelTraitSet traitSet, |
| RelNode input, ImmutableBitSet groupSet, |
| List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { |
| super(cluster, traitSet, ImmutableList.of(), input, groupSet, groupSets, aggCalls); |
| assert getConvention() == PigRel.CONVENTION; |
| } |
| |
| @Deprecated // to be removed before 2.0 |
| public PigAggregate(RelOptCluster cluster, RelTraitSet traitSet, |
| RelNode input, boolean indicator, ImmutableBitSet groupSet, |
| List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) { |
| this(cluster, traitSet, input, groupSet, groupSets, aggCalls); |
| checkIndicator(indicator); |
| } |
| |
| @Override public Aggregate copy(RelTraitSet traitSet, RelNode input, |
| ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, |
| List<AggregateCall> aggCalls) { |
| return new PigAggregate(input.getCluster(), traitSet, input, groupSet, |
| groupSets, aggCalls); |
| } |
| |
| @Override public void implement(Implementor implementor) { |
| implementor.visitChild(0, getInput()); |
| implementor.addStatement(getPigAggregateStatement(implementor)); |
| } |
| |
| /** |
| * Generates a GROUP BY statement, followed by an optional FOREACH statement |
| * for all aggregate functions used. e.g. |
| * <pre> |
| * {@code |
| * A = GROUP A BY owner; |
| * A = FOREACH A GENERATE group, SUM(A.pet_num); |
| * } |
| * </pre> |
| */ |
| private String getPigAggregateStatement(Implementor implementor) { |
| return getPigGroupBy(implementor) + '\n' + getPigForEachGenerate(implementor); |
| } |
| |
| /** |
| * Override this method so it looks down the tree to find the table this node |
| * is acting on. |
| */ |
| @Override public RelOptTable getTable() { |
| return getInput().getTable(); |
| } |
| |
| /** |
| * Generates the GROUP BY statement, e.g. |
| * <code>A = GROUP A BY (f1, f2);</code> |
| */ |
| private String getPigGroupBy(Implementor implementor) { |
| final String relAlias = implementor.getPigRelationAlias(this); |
| final List<RelDataTypeField> allFields = getInput().getRowType().getFieldList(); |
| final List<Integer> groupedFieldIndexes = groupSet.asList(); |
| if (groupedFieldIndexes.size() < 1) { |
| return relAlias + " = GROUP " + relAlias + " ALL;"; |
| } else { |
| final List<String> groupedFieldNames = new ArrayList<>(groupedFieldIndexes.size()); |
| for (int fieldIndex : groupedFieldIndexes) { |
| groupedFieldNames.add(allFields.get(fieldIndex).getName()); |
| } |
| return relAlias + " = GROUP " + relAlias + " BY (" |
| + String.join(", ", groupedFieldNames) + ");"; |
| } |
| } |
| |
| /** |
| * Generates a FOREACH statement containing invocation of aggregate functions |
| * and projection of grouped fields. e.g. |
| * <code>A = FOREACH A GENERATE group, SUM(A.pet_num);</code> |
| * @see Pig documentation for special meaning of the "group" field after GROUP |
| * BY. |
| */ |
| private String getPigForEachGenerate(Implementor implementor) { |
| final String relAlias = implementor.getPigRelationAlias(this); |
| final String generateCall = getPigGenerateCall(implementor); |
| final List<String> distinctCalls = getDistinctCalls(implementor); |
| return relAlias + " = FOREACH " + relAlias + " {\n" |
| + String.join(";\n", distinctCalls) + generateCall + "\n};"; |
| } |
| |
| private String getPigGenerateCall(Implementor implementor) { |
| final List<Integer> groupedFieldIndexes = groupSet.asList(); |
| Set<String> groupFields = new HashSet<>(groupedFieldIndexes.size()); |
| for (int fieldIndex : groupedFieldIndexes) { |
| final String fieldName = getInputFieldName(fieldIndex); |
| // Pig appends group field name if grouping by multiple fields |
| final String groupField = (groupedFieldIndexes.size() == 1 ? "group" : ("group." + fieldName)) |
| + " AS " + fieldName; |
| |
| groupFields.add(groupField); |
| } |
| final List<String> pigAggCalls = getPigAggregateCalls(implementor); |
| List<String> allFields = new ArrayList<>(groupFields.size() + pigAggCalls.size()); |
| allFields.addAll(groupFields); |
| allFields.addAll(pigAggCalls); |
| return " GENERATE " + String.join(", ", allFields) + ';'; |
| } |
| |
| private List<String> getPigAggregateCalls(Implementor implementor) { |
| final String relAlias = implementor.getPigRelationAlias(this); |
| final List<String> result = new ArrayList<>(aggCalls.size()); |
| for (AggregateCall ac : aggCalls) { |
| result.add(getPigAggregateCall(relAlias, ac)); |
| } |
| return result; |
| } |
| |
| private String getPigAggregateCall(String relAlias, AggregateCall aggCall) { |
| final PigAggFunction aggFunc = toPigAggFunc(aggCall); |
| final String alias = aggCall.getName(); |
| final String fields = String.join(", ", getArgNames(relAlias, aggCall)); |
| return aggFunc.name() + "(" + fields + ") AS " + alias; |
| } |
| |
| private static PigAggFunction toPigAggFunc(AggregateCall aggCall) { |
| return PigAggFunction.valueOf(aggCall.getAggregation().getKind(), |
| aggCall.getArgList().size() < 1); |
| } |
| |
| private List<String> getArgNames(String relAlias, AggregateCall aggCall) { |
| final List<String> result = new ArrayList<>(aggCall.getArgList().size()); |
| for (int fieldIndex : aggCall.getArgList()) { |
| result.add(getInputFieldNameForAggCall(relAlias, aggCall, fieldIndex)); |
| } |
| return result; |
| } |
| |
| private String getInputFieldNameForAggCall(String relAlias, AggregateCall aggCall, |
| int fieldIndex) { |
| final String inputField = getInputFieldName(fieldIndex); |
| return aggCall.isDistinct() ? (inputField + DISTINCT_FIELD_SUFFIX) |
| : (relAlias + '.' + inputField); |
| } |
| |
| /** |
| * Returns the calls to aggregate functions that have the {@code DISTINT} flag. |
| * |
| * <p>An aggregate function call like <code>COUNT(DISTINCT COL)</code> in Pig |
| * is achieved via two statements in a {@code FOREACH} that follows a |
| * {@code GROUP} statement: |
| * |
| * <blockquote> |
| * <code> |
| * TABLE = GROUP TABLE ALL;<br> |
| * TABLE = FOREACH TABLE {<br> |
| * <b>COL.DISTINCT = DISTINCT COL;<br> |
| * GENERATE COUNT(COL.DISTINCT) AS C;</b><br> |
| * }</code> |
| * </blockquote> |
| */ |
| private List<String> getDistinctCalls(Implementor implementor) { |
| final String relAlias = implementor.getPigRelationAlias(this); |
| final List<String> result = new ArrayList<>(); |
| for (AggregateCall aggCall : aggCalls) { |
| if (aggCall.isDistinct()) { |
| for (int fieldIndex : aggCall.getArgList()) { |
| String fieldName = getInputFieldName(fieldIndex); |
| result.add(" " + fieldName + DISTINCT_FIELD_SUFFIX + " = DISTINCT " |
| + relAlias + '.' + fieldName + ";\n"); |
| } |
| } |
| } |
| return result; |
| } |
| |
| private String getInputFieldName(int fieldIndex) { |
| return getInput().getRowType().getFieldList().get(fieldIndex).getName(); |
| } |
| } |