/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.physical.impl.agg;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.exec.physical.config.HashAggregate;
import org.apache.drill.exec.physical.impl.MockRecordBatch;
import org.apache.drill.exec.physical.rowSet.RowSet;
import org.apache.drill.exec.physical.rowSet.RowSetBuilder;
import org.apache.drill.exec.planner.physical.AggPrelBase;
import org.apache.drill.exec.record.metadata.SchemaBuilder;
import org.apache.drill.exec.record.metadata.TupleMetadata;
import org.apache.drill.test.PhysicalOpUnitTestBase;

import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

import static org.apache.drill.exec.ExecConstants.HASHAGG_NUM_PARTITIONS_KEY;
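
/**
 * Unit tests for the hash aggregate batch. Each test streams generated
 * (firstname, lastname, stuff) rows into a {@link HashAggregate} configured
 * for a particular operator phase, then verifies the per-group sums against
 * an analytically computed expected row set.
 */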
public class TestHashAggBatch extends PhysicalOpUnitTestBase {
  public static final String FIRST_NAME_COL = "firstname";
  public static final String LAST_NAME_COL = "lastname";
  public static final String STUFF_COL = "stuff";
  public static final String TOTAL_STUFF_COL = "totalstuff";

  public static final List<String> FIRST_NAMES = ImmutableList.of(
    "Strawberry",
    "Banana",
    "Mango",
    "Grape");

  public static final List<String> LAST_NAMES = ImmutableList.of(
    "Red",
    "Green",
    "Blue",
    "Purple");

  public static final TupleMetadata INT_OUTPUT_SCHEMA = new SchemaBuilder()
    .add(FIRST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
    .add(LAST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
    .add(TOTAL_STUFF_COL, TypeProtos.MinorType.BIGINT, TypeProtos.DataMode.OPTIONAL)
    .buildSchema();

  // All tests currently run with a single hash-agg partition.
  // TODO remove this in order to test multiple partitions
  @Before
  public void setup() {
    operatorFixture.getOptionManager().setLocalOption(HASHAGG_NUM_PARTITIONS_KEY, 1);
  }

  @Test
  public void simpleSingleBatchSumTestPhase1of2() throws Exception {
    batchSumTest(100, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_1of2);
  }

  @Test
  public void simpleMultiBatchSumTestPhase1of2() throws Exception {
    batchSumTest(100, 100, AggPrelBase.OperatorPhase.PHASE_1of2);
  }

  @Test
  public void simpleSingleBatchSumTestPhase1of1() throws Exception {
    batchSumTest(100, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_1of1);
  }

  @Test
  public void simpleMultiBatchSumTestPhase1of1() throws Exception {
    batchSumTest(100, 100, AggPrelBase.OperatorPhase.PHASE_1of1);
  }

  @Test
  public void simpleSingleBatchSumTestPhase2of2() throws Exception {
    batchSumTest(100, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_2of2);
  }

  @Test
  public void simpleMultiBatchSumTestPhase2of2() throws Exception {
    batchSumTest(100, 100, AggPrelBase.OperatorPhase.PHASE_2of2);
  }
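
  /**
   * Shared driver for the sum tests: generates {@code totalCount} rows for
   * each (firstname, lastname) pair (so {@code totalCount} is the per-group
   * row count, not the overall total), slices them into input batches of at
   * most {@code maxInputBatchSize} rows, runs them through a hash aggregate
   * in the given operator phase, and checks the summed output.
   */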
  private void batchSumTest(int totalCount, int maxInputBatchSize, AggPrelBase.OperatorPhase phase) throws Exception {
    final HashAggregate hashAggregate = createHashAggPhysicalOperator(phase);
    final List<RowSet> inputRowSets = buildInputRowSets(TypeProtos.MinorType.INT, TypeProtos.DataMode.REQUIRED,
      totalCount, maxInputBatchSize);

    final MockRecordBatch.Builder rowSetBatchBuilder = new MockRecordBatch.Builder();
    inputRowSets.forEach(rowSetBatchBuilder::sendData);
    final MockRecordBatch inputRowSetBatch = rowSetBatchBuilder.build(fragContext);

    final RowSet expectedRowSet = buildIntExpectedRowSet(totalCount);

    opTestBuilder()
      .physicalOperator(hashAggregate)
      .combineOutputBatches()
      .unordered()
      .addUpstreamBatch(inputRowSetBatch)
      .addExpectedResult(expectedRowSet)
      .go();
  }
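
  /**
   * Builds a {@link HashAggregate} definition equivalent to
   * SELECT firstname, lastname, SUM(stuff) AS totalstuff ...
   * GROUP BY firstname, lastname, for the given operator phase.
   */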
  private HashAggregate createHashAggPhysicalOperator(AggPrelBase.OperatorPhase phase) {
    final List<NamedExpression> keyExpressions = Lists.newArrayList(
      new NamedExpression(SchemaPath.getSimplePath(FIRST_NAME_COL), new FieldReference(FIRST_NAME_COL)),
      new NamedExpression(SchemaPath.getSimplePath(LAST_NAME_COL), new FieldReference(LAST_NAME_COL)));

    final List<NamedExpression> aggExpressions = Lists.newArrayList(
      new NamedExpression(
        new FunctionCall("sum", ImmutableList.of(SchemaPath.getSimplePath(STUFF_COL)),
          new ExpressionPosition(null, 0)),
        new FieldReference(TOTAL_STUFF_COL)));

    return new HashAggregate(
      null, // no child operator; input is supplied directly via MockRecordBatch
      phase,
      keyExpressions,
      aggExpressions,
      0.0f); // cardinality estimate; unused by these tests
  }
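
  // Input schema: the two required VARCHAR key columns plus a "stuff" column
  // of the requested minor type and data mode.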
  private TupleMetadata buildInputSchema(TypeProtos.MinorType minorType, TypeProtos.DataMode dataMode) {
    return new SchemaBuilder()
      .add(FIRST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
      .add(LAST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
      .add(STUFF_COL, minorType, dataMode)
      .buildSchema();
  }
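
  /**
   * Generates the input data: each of the 16 (firstname, lastname) pairs is
   * assigned a distinct multiplier m in [1, 16] and contributes the rows
   * (firstname, lastname, m * 1) ... (firstname, lastname, m * dataCount),
   * sliced into row sets of at most maxBatchSize rows each.
   */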
  private List<RowSet> buildInputRowSets(final TypeProtos.MinorType minorType,
                                         final TypeProtos.DataMode dataMode,
                                         final int dataCount,
                                         final int maxBatchSize) {
    Preconditions.checkArgument(dataCount > 0);
    Preconditions.checkArgument(maxBatchSize > 0);

    final List<RowSet> inputRowSets = new ArrayList<>();
    int currentBatchSize = 0;
    RowSetBuilder inputRowSetBuilder = null;

    for (int multiplier = 1, firstNameIndex = 0; firstNameIndex < FIRST_NAMES.size(); firstNameIndex++) {
      final String firstName = FIRST_NAMES.get(firstNameIndex);

      for (int lastNameIndex = 0; lastNameIndex < LAST_NAMES.size(); lastNameIndex++, multiplier++) {
        final String lastName = LAST_NAMES.get(lastNameIndex);

        for (int index = 1; index <= dataCount; index++) {
          final int num = index * multiplier;

          if (currentBatchSize == 0) {
            // Start a new batch.
            final TupleMetadata inputSchema = buildInputSchema(minorType, dataMode);
            inputRowSetBuilder = new RowSetBuilder(operatorFixture.allocator(), inputSchema);
          }

          inputRowSetBuilder.addRow(firstName, lastName, num);
          currentBatchSize++;

          if (currentBatchSize == maxBatchSize) {
            // The batch is full; emit it and start over.
            inputRowSets.add(inputRowSetBuilder.build());
            currentBatchSize = 0;
          }
        }
      }
    }

    if (currentBatchSize != 0) {
      // Emit the final, partially filled batch.
      inputRowSets.add(inputRowSetBuilder.build());
    }

    return inputRowSets;
  }
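
  /**
   * Computes the expected aggregate output: for the group with multiplier m
   * the input values are m * 1 ... m * dataCount, so the expected sum is
   * m * dataCount * (dataCount + 1) / 2.
   */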
  private RowSet buildIntExpectedRowSet(final int dataCount) {
    final RowSetBuilder expectedRowSetBuilder = new RowSetBuilder(operatorFixture.allocator(), INT_OUTPUT_SCHEMA);

    for (int multiplier = 1, firstNameIndex = 0; firstNameIndex < FIRST_NAMES.size(); firstNameIndex++) {
      final String firstName = FIRST_NAMES.get(firstNameIndex);

      for (int lastNameIndex = 0; lastNameIndex < LAST_NAMES.size(); lastNameIndex++, multiplier++) {
        final String lastName = LAST_NAMES.get(lastNameIndex);
        // Use long arithmetic so the closed-form sum cannot overflow an int
        // for large dataCount values.
        final long total = ((long) dataCount * (dataCount + 1) / 2) * multiplier;
        expectedRowSetBuilder.addRow(firstName, lastName, total);
      }
    }

    return expectedRowSetBuilder.build();
  }
}