Fix some bugs of aggregation in TableModel
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java
index feda477..490bb17 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java
@@ -714,4 +714,25 @@
+ "ORDER BY time, t1.device, t2.device";
tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
}
+
+ // ========== Aggregation Test =========
+ @Test
+ public void globalAggregationTest() {
+ String[] expectedHeader = new String[] {"_col0"};
+ String[] retArray = new String[] {"30,"};
+
+ String sql = "SELECT count(num+1) from table0";
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+ }
+
+ @Test
+ public void countStarTest() {
+ String[] expectedHeader = new String[] {"_col0", "_col1"};
+ String[] retArray = new String[] {"1,1,"};
+
+ String sql = "select count(*),count(t1) from (select avg(num+1) as t1 from table0)";
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ // TODO select count(*),count(t1) from (select avg(num+1) as t1 from table0) where time < 0
+ }
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
index f5157f1..e9205e3 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
@@ -55,10 +55,14 @@
List<Expression> inputExpressions,
Map<String, String> inputAttributes,
boolean ascending) {
- return isMultiInputAggregation(aggregationType)
- ? createBuiltinMultiInputAccumulator(aggregationType, inputDataTypes)
- : createBuiltinSingleInputAccumulator(
- aggregationType, inputDataTypes.get(0), inputExpressions, inputAttributes, ascending);
+ switch (aggregationType) {
+ case COUNT:
+ return new CountAccumulator();
+ case AVG:
+ return new AvgAccumulator(inputDataTypes.get(0));
+ default:
+ throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
+ }
}
public static boolean isMultiInputAggregation(TAggregationType aggregationType) {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Aggregator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Aggregator.java
index cbd465e..ada8ce9 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Aggregator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Aggregator.java
@@ -20,12 +20,14 @@
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import java.util.List;
import java.util.OptionalInt;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
+import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator.TIME_COLUMN_TEMPLATE;
public class Aggregator {
private final Accumulator accumulator;
@@ -55,11 +57,18 @@
}
public void processBlock(TsBlock block) {
+ Column[] arguments = block.getColumns(inputChannels);
+
+ // process count(*)
+ if (arguments.length == 0) {
+ arguments =
+ new Column[] {new RunLengthEncodedColumn(TIME_COLUMN_TEMPLATE, block.getPositionCount())};
+ }
+
if (step.isInputRaw()) {
- Column[] arguments = block.getColumns(inputChannels);
accumulator.addInput(arguments);
} else {
- accumulator.addIntermediate(block.getColumn(inputChannels[0]));
+ accumulator.addIntermediate(arguments[0]);
}
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index 6f109e2..971e1c7 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -1208,12 +1208,17 @@
symbol ->
aggregatorBuilder.add(
buildAggregator(
- childLayout, aggregationMap.get(symbol), node.getStep(), typeProvider)));
+ childLayout,
+ symbol,
+ aggregationMap.get(symbol),
+ node.getStep(),
+ typeProvider)));
return new AggregationOperator(context, child, aggregatorBuilder.build());
}
private Aggregator buildAggregator(
Map<Symbol, Integer> childLayout,
+ Symbol symbol,
AggregationNode.Aggregation aggregation,
AggregationNode.Step step,
TypeProvider typeProvider) {
@@ -1245,7 +1250,7 @@
return new Aggregator(
accumulator,
step,
- getTSDataType(aggregation.getResolvedFunction().getSignature().getReturnType()),
+ getTSDataType(typeProvider.getTableModelType(symbol)),
argumentChannels,
OptionalInt.empty());
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
index 72f0cda..e56675e 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
@@ -31,7 +31,7 @@
import java.util.stream.Collectors;
import static org.apache.tsfile.read.common.type.DoubleType.DOUBLE;
-import static org.apache.tsfile.read.common.type.IntType.INT32;
+import static org.apache.tsfile.read.common.type.LongType.INT64;
public enum TableBuiltinAggregationFunction {
SUM("sum"),
@@ -109,8 +109,10 @@
}
public static List<Type> getIntermediateTypes(String name, List<Type> originalArgumentTypes) {
- if (AVG.functionName.equalsIgnoreCase(name)) {
- return ImmutableList.of(DOUBLE, INT32);
+ if (COUNT.functionName.equalsIgnoreCase(name)) {
+ return ImmutableList.of(INT64);
+ } else if (AVG.functionName.equalsIgnoreCase(name)) {
+ return ImmutableList.of(DOUBLE, INT64);
} else {
return ImmutableList.copyOf(originalArgumentTypes);
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationNode.java
index 9ab5ea8..898d837 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationNode.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationNode.java
@@ -37,8 +37,8 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -292,7 +292,7 @@
public static AggregationNode deserialize(ByteBuffer byteBuffer) {
int size = ReadWriteIOUtils.readInt(byteBuffer);
- final Map<Symbol, Aggregation> aggregations = new HashMap<>(size);
+ final Map<Symbol, Aggregation> aggregations = new LinkedHashMap<>(size);
while (size-- > 0) {
aggregations.put(Symbol.deserialize(byteBuffer), Aggregation.deserialize(byteBuffer));
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java
index b33e9b3..fe40770 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java
@@ -38,6 +38,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -489,7 +490,7 @@
}
size = ReadWriteIOUtils.readInt(byteBuffer);
- final Map<Symbol, AggregationNode.Aggregation> aggregations = new HashMap<>(size);
+ final Map<Symbol, AggregationNode.Aggregation> aggregations = new LinkedHashMap<>(size);
while (size-- > 0) {
aggregations.put(
Symbol.deserialize(byteBuffer), AggregationNode.Aggregation.deserialize(byteBuffer));
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
index c256e53..34965df 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
@@ -27,7 +27,7 @@
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.utils.Pair;
-import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -46,8 +46,8 @@
*/
public static Pair<AggregationNode, AggregationNode> split(
AggregationNode node, SymbolAllocator symbolAllocator, QueryId queryId) {
- Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
- Map<Symbol, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
+ Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new LinkedHashMap<>();
+ Map<Symbol, AggregationNode.Aggregation> finalAggregation = new LinkedHashMap<>();
for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
AggregationNode.Aggregation originalAggregation = entry.getValue();
ResolvedFunction resolvedFunction = originalAggregation.getResolvedFunction();
@@ -116,8 +116,8 @@
*/
public static Pair<AggregationNode, AggregationTableScanNode> split(
AggregationTableScanNode node, SymbolAllocator symbolAllocator, QueryId queryId) {
- Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
- Map<Symbol, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
+ Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new LinkedHashMap<>();
+ Map<Symbol, AggregationNode.Aggregation> finalAggregation = new LinkedHashMap<>();
for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
AggregationNode.Aggregation originalAggregation = entry.getValue();
ResolvedFunction resolvedFunction = originalAggregation.getResolvedFunction();