/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.physical.impl.agg;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.exec.physical.config.HashAggregate;
import org.apache.drill.exec.physical.impl.MockRecordBatch;
import org.apache.drill.exec.planner.physical.AggPrelBase;
import org.apache.drill.exec.record.metadata.SchemaBuilder;
import org.apache.drill.exec.record.metadata.TupleMetadata;
import org.apache.drill.test.PhysicalOpUnitTestBase;
import org.apache.drill.exec.physical.rowSet.RowSet;
import org.apache.drill.exec.physical.rowSet.RowSetBuilder;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

import static org.apache.drill.exec.ExecConstants.HASHAGG_NUM_PARTITIONS_KEY;

/**
 * Unit tests for the hash aggregate operator. Each test groups by
 * ({@link #FIRST_NAME_COL}, {@link #LAST_NAME_COL}) and sums {@link #STUFF_COL}
 * into {@link #TOTAL_STUFF_COL}, feeding the input either as one batch or split
 * into several batches, for each {@link AggPrelBase.OperatorPhase}.
 */
public class TestHashAggBatch extends PhysicalOpUnitTestBase {
  public static final String FIRST_NAME_COL = "firstname";
  public static final String LAST_NAME_COL = "lastname";
  public static final String STUFF_COL = "stuff";
  public static final String TOTAL_STUFF_COL = "totalstuff";

  public static final List<String> FIRST_NAMES = ImmutableList.of(
    "Strawberry",
    "Banana",
    "Mango",
    "Grape");

  public static final List<String> LAST_NAMES = ImmutableList.of(
    "Red",
    "Green",
    "Blue",
    "Purple");

  /** Expected output schema: both group-by keys plus the nullable BIGINT sum. */
  public static final TupleMetadata INT_OUTPUT_SCHEMA = new SchemaBuilder()
    .add(FIRST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
    .add(LAST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
    .add(TOTAL_STUFF_COL, TypeProtos.MinorType.BIGINT, TypeProtos.DataMode.OPTIONAL)
    .buildSchema();

  /**
   * Runs before every test in this class (not only the one its name refers to)
   * and forces the hash aggregate to use a single partition.
   */
  // TODO: remove this override so that the tests also exercise multiple partitions.
  @Before
  public void setupSimpleSingleBatchSumTestPhase1of2() {
    operatorFixture.getOptionManager().setLocalOption(HASHAGG_NUM_PARTITIONS_KEY, 1);
  }

  @Test
  public void simpleSingleBatchSumTestPhase1of2() throws Exception {
    batchSumTest(100, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_1of2);
  }

  @Test
  public void simpleMultiBatchSumTestPhase1of2() throws Exception {
    batchSumTest(100, 100, AggPrelBase.OperatorPhase.PHASE_1of2);
  }

  @Test
  public void simpleSingleBatchSumTestPhase1of1() throws Exception {
    batchSumTest(100, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_1of1);
  }

  @Test
  public void simpleMultiBatchSumTestPhase1of1() throws Exception {
    batchSumTest(100, 100, AggPrelBase.OperatorPhase.PHASE_1of1);
  }

  @Test
  public void simpleSingleBatchSumTestPhase2of2() throws Exception {
    batchSumTest(100, Integer.MAX_VALUE, AggPrelBase.OperatorPhase.PHASE_2of2);
  }

  @Test
  public void simpleMultiBatchSumTestPhase2of2() throws Exception {
    batchSumTest(100, 100, AggPrelBase.OperatorPhase.PHASE_2of2);
  }

  /**
   * Builds the input, runs the hash aggregate over it, and compares the combined,
   * unordered output with the expected per-group sums.
   *
   * @param totalCount number of rows generated for each (first name, last name) group
   * @param maxInputBatchSize maximum number of rows per upstream input batch
   * @param phase aggregation phase to configure on the operator
   */
  private void batchSumTest(int totalCount, int maxInputBatchSize, AggPrelBase.OperatorPhase phase) throws Exception {
    final HashAggregate hashAggregate = createHashAggPhysicalOperator(phase);
    final List<RowSet> inputRowSets = buildInputRowSets(TypeProtos.MinorType.INT, TypeProtos.DataMode.REQUIRED,
      totalCount, maxInputBatchSize);

    final MockRecordBatch.Builder rowSetBatchBuilder = new MockRecordBatch.Builder();
    inputRowSets.forEach(rowSetBatchBuilder::sendData);
    final MockRecordBatch inputRowSetBatch = rowSetBatchBuilder.build(fragContext);

    final RowSet expectedRowSet = buildIntExpectedRowSet(totalCount);

    opTestBuilder()
      .physicalOperator(hashAggregate)
      .combineOutputBatches()
      .unordered()
      .addUpstreamBatch(inputRowSetBatch)
      .addExpectedResult(expectedRowSet)
      .go();
  }

  /**
   * Creates a hash aggregate config that groups by first and last name and
   * computes {@code sum(stuff)} as {@code totalstuff}.
   *
   * @param phase aggregation phase for the operator
   * @return the physical operator config (no child; the test supplies upstream batches)
   */
  private HashAggregate createHashAggPhysicalOperator(AggPrelBase.OperatorPhase phase) {
    final List<NamedExpression> keyExpressions = Lists.newArrayList(
      new NamedExpression(SchemaPath.getSimplePath(FIRST_NAME_COL), new FieldReference(FIRST_NAME_COL)),
      new NamedExpression(SchemaPath.getSimplePath(LAST_NAME_COL), new FieldReference(LAST_NAME_COL)));

    final List<NamedExpression> aggExpressions = Lists.newArrayList(
      new NamedExpression(
        new FunctionCall("sum", ImmutableList.of(SchemaPath.getSimplePath(STUFF_COL)),
          new ExpressionPosition(null, 0)),
        new FieldReference(TOTAL_STUFF_COL)));

    return new HashAggregate(
      null,
      phase,
      keyExpressions,
      aggExpressions,
      0.0f);
  }

  /**
   * Builds the input schema: two required VARCHAR key columns plus the value
   * column with the requested type and mode.
   */
  private TupleMetadata buildInputSchema(TypeProtos.MinorType minorType, TypeProtos.DataMode dataMode) {
    return new SchemaBuilder()
      .add(FIRST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
      .add(LAST_NAME_COL, TypeProtos.MinorType.VARCHAR, TypeProtos.DataMode.REQUIRED)
      .add(STUFF_COL, minorType, dataMode)
      .buildSchema();
  }

  /**
   * Generates the input rows and splits them into row sets of at most
   * {@code maxBatchSize} rows each.
   *
   * <p>Groups are visited in (first name, last name) order, and each group is
   * given a distinct multiplier starting at 1. Within a group, the value column
   * holds {@code index * multiplier} for {@code index = 1..dataCount}.
   *
   * @param minorType type of the value column
   * @param dataMode mode of the value column
   * @param dataCount number of rows per group; must be positive
   * @param maxBatchSize maximum rows per returned row set; must be positive
   * @return the row sets in generation order; the final one may hold fewer rows
   * @throws IllegalArgumentException if {@code dataCount} or {@code maxBatchSize} is not positive
   */
  private List<RowSet> buildInputRowSets(final TypeProtos.MinorType minorType,
                                  final TypeProtos.DataMode dataMode,
                                  final int dataCount,
                                  final int maxBatchSize) {
    Preconditions.checkArgument(dataCount > 0, "dataCount must be positive: %s", dataCount);
    Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be positive: %s", maxBatchSize);

    List<RowSet> inputRowSets = new ArrayList<>();
    int currentBatchSize = 0;
    RowSetBuilder inputRowSetBuilder = null;

    for (int multiplier = 1, firstNameIndex = 0; firstNameIndex < FIRST_NAMES.size(); firstNameIndex++) {
      final String firstName = FIRST_NAMES.get(firstNameIndex);

      for (int lastNameIndex = 0; lastNameIndex < LAST_NAMES.size(); lastNameIndex++, multiplier++) {
        final String lastName = LAST_NAMES.get(lastNameIndex);

        for (int index = 1; index <= dataCount; index++) {
          final int num = index * multiplier;

          // Start a new builder lazily, so no empty trailing row set is created.
          if (currentBatchSize == 0) {
            final TupleMetadata inputSchema = buildInputSchema(minorType, dataMode);
            inputRowSetBuilder = new RowSetBuilder(operatorFixture.allocator(), inputSchema);
          }

          inputRowSetBuilder.addRow(firstName, lastName, num);
          currentBatchSize++;

          if (currentBatchSize == maxBatchSize) {
            final RowSet rowSet = inputRowSetBuilder.build();
            inputRowSets.add(rowSet);
            currentBatchSize = 0;
          }
        }
      }
    }

    // Flush the last, partially filled batch.
    if (currentBatchSize != 0) {
      inputRowSets.add(inputRowSetBuilder.build());
    }

    return inputRowSets;
  }

  /**
   * Builds the expected aggregate output for the input produced by
   * {@link #buildInputRowSets}: one row per group whose total is
   * {@code multiplier * (1 + 2 + ... + dataCount)}.
   *
   * @param dataCount number of rows per group in the input
   * @return the expected row set in {@link #INT_OUTPUT_SCHEMA}
   */
  private RowSet buildIntExpectedRowSet(final int dataCount) {
    final RowSetBuilder expectedRowSetBuilder = new RowSetBuilder(operatorFixture.allocator(), INT_OUTPUT_SCHEMA);

    // Widen to long before multiplying. The int form overflowed for large
    // dataCount even though the aggregate itself is BIGINT.
    final long triangular = (long) dataCount * (dataCount + 1) / 2;

    for (int multiplier = 1, firstNameIndex = 0; firstNameIndex < FIRST_NAMES.size(); firstNameIndex++) {
      final String firstName = FIRST_NAMES.get(firstNameIndex);

      for (int lastNameIndex = 0; lastNameIndex < LAST_NAMES.size(); lastNameIndex++, multiplier++) {
        final String lastName = LAST_NAMES.get(lastNameIndex);
        final long total = triangular * multiplier;

        expectedRowSetBuilder.addRow(firstName, lastName, total);
      }
    }

    return expectedRowSetBuilder.build();
  }
}
