| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.druid.sql.calcite.aggregation; |
| |
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
| |
| public class Aggregation |
| { |
| private final List<AggregatorFactory> aggregatorFactories; |
| private final PostAggregator postAggregator; |
| |
| private Aggregation( |
| final List<AggregatorFactory> aggregatorFactories, |
| final PostAggregator postAggregator |
| ) |
| { |
| this.aggregatorFactories = Preconditions.checkNotNull(aggregatorFactories, "aggregatorFactories"); |
| this.postAggregator = postAggregator; |
| |
| if (aggregatorFactories.isEmpty()) { |
| Preconditions.checkArgument(postAggregator != null, "postAggregator must be present if there are no aggregators"); |
| } |
| |
| if (postAggregator == null) { |
| Preconditions.checkArgument(aggregatorFactories.size() == 1, "aggregatorFactories.size == 1"); |
| } else { |
| // Verify that there are no "useless" fields in the aggregatorFactories. |
| // Don't verify that the PostAggregator inputs are all present; they might not be. |
| final Set<String> dependentFields = postAggregator.getDependentFields(); |
| for (AggregatorFactory aggregatorFactory : aggregatorFactories) { |
| if (!dependentFields.contains(aggregatorFactory.getName())) { |
| throw new IAE("Unused field[%s] in Aggregation", aggregatorFactory.getName()); |
| } |
| } |
| } |
| |
| // Verify that all "internal" aggregator names are prefixed by the output name of this aggregation. |
| // This is a sanity check to make sure callers are behaving as they should be. |
| final String name = postAggregator != null |
| ? postAggregator.getName() |
| : Iterables.getOnlyElement(aggregatorFactories).getName(); |
| |
| for (AggregatorFactory aggregatorFactory : aggregatorFactories) { |
| if (!aggregatorFactory.getName().startsWith(name)) { |
| throw new IAE("Aggregator[%s] not prefixed under[%s]", aggregatorFactory.getName(), name); |
| } |
| } |
| } |
| |
| |
| public static Aggregation create(final AggregatorFactory aggregatorFactory) |
| { |
| return new Aggregation( |
| ImmutableList.of(aggregatorFactory), |
| null |
| ); |
| } |
| |
| public static Aggregation create(final PostAggregator postAggregator) |
| { |
| return new Aggregation(Collections.emptyList(), postAggregator); |
| } |
| |
| public static Aggregation create( |
| final List<AggregatorFactory> aggregatorFactories, |
| final PostAggregator postAggregator |
| ) |
| { |
| return new Aggregation(aggregatorFactories, postAggregator); |
| } |
| |
| public List<String> getRequiredColumns() |
| { |
| Set<String> columns = new HashSet<>(); |
| for (AggregatorFactory agg : aggregatorFactories) { |
| columns.addAll(agg.requiredFields()); |
| } |
| if (postAggregator != null) { |
| columns.addAll(postAggregator.getDependentFields()); |
| } |
| return ImmutableList.copyOf(columns); |
| } |
| |
| public List<AggregatorFactory> getAggregatorFactories() |
| { |
| return aggregatorFactories; |
| } |
| |
| @Nullable |
| public PostAggregator getPostAggregator() |
| { |
| return postAggregator; |
| } |
| |
| public String getOutputName() |
| { |
| return postAggregator != null |
| ? postAggregator.getName() |
| : Iterables.getOnlyElement(aggregatorFactories).getName(); |
| } |
| |
| public Aggregation filter( |
| final RowSignature rowSignature, |
| final VirtualColumnRegistry virtualColumnRegistry, |
| final DimFilter filter |
| ) |
| { |
| if (filter == null) { |
| return this; |
| } |
| |
| |
| if (postAggregator != null) { |
| // Verify that this Aggregation contains all input to its postAggregator. |
| // If not, this "filter" call won't work right. |
| final Set<String> dependentFields = postAggregator.getDependentFields(); |
| final Set<String> aggregatorNames = new HashSet<>(); |
| for (AggregatorFactory aggregatorFactory : aggregatorFactories) { |
| aggregatorNames.add(aggregatorFactory.getName()); |
| } |
| for (String field : dependentFields) { |
| if (!aggregatorNames.contains(field)) { |
| throw new ISE("Cannot filter an Aggregation that does not contain its inputs: %s", this); |
| } |
| } |
| } |
| |
| final DimFilter baseOptimizedFilter = Filtration.create(filter) |
| .optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature()) |
| .getDimFilter(); |
| |
| final List<AggregatorFactory> newAggregators = new ArrayList<>(); |
| for (AggregatorFactory agg : aggregatorFactories) { |
| if (agg instanceof FilteredAggregatorFactory) { |
| final FilteredAggregatorFactory filteredAgg = (FilteredAggregatorFactory) agg; |
| newAggregators.add( |
| new FilteredAggregatorFactory( |
| filteredAgg.getAggregator(), |
| Filtration.create(new AndDimFilter(ImmutableList.of(filteredAgg.getFilter(), baseOptimizedFilter))) |
| .optimizeFilterOnly(virtualColumnRegistry.getFullRowSignature()) |
| .getDimFilter() |
| ) |
| ); |
| } else { |
| newAggregators.add(new FilteredAggregatorFactory(agg, baseOptimizedFilter)); |
| } |
| } |
| |
| return new Aggregation(newAggregators, postAggregator); |
| } |
| |
| @Override |
| public boolean equals(final Object o) |
| { |
| if (this == o) { |
| return true; |
| } |
| if (o == null || getClass() != o.getClass()) { |
| return false; |
| } |
| final Aggregation that = (Aggregation) o; |
| return Objects.equals(aggregatorFactories, that.aggregatorFactories) && |
| Objects.equals(postAggregator, that.postAggregator); |
| } |
| |
| @Override |
| public int hashCode() |
| { |
| return Objects.hash(aggregatorFactories, postAggregator); |
| } |
| |
| @Override |
| public String toString() |
| { |
| return "Aggregation{" + |
| "aggregatorFactories=" + aggregatorFactories + |
| ", postAggregator=" + postAggregator + |
| '}'; |
| } |
| } |