/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 **/

#include "query_optimizer/rules/ExtractCommonSubexpression.hpp"

#include <cstddef>
#include <memory>
#include <unordered_set>
#include <vector>

#include "query_optimizer/OptimizerContext.hpp"
#include "query_optimizer/expressions/AggregateFunction.hpp"
#include "query_optimizer/expressions/Alias.hpp"
#include "query_optimizer/expressions/CommonSubexpression.hpp"
#include "query_optimizer/expressions/ExpressionType.hpp"
#include "query_optimizer/expressions/NamedExpression.hpp"
#include "query_optimizer/expressions/PatternMatcher.hpp"
#include "query_optimizer/expressions/Scalar.hpp"
#include "query_optimizer/physical/Aggregate.hpp"
#include "query_optimizer/physical/HashJoin.hpp"
#include "query_optimizer/physical/NestedLoopsJoin.hpp"
#include "query_optimizer/physical/Physical.hpp"
#include "query_optimizer/physical/PhysicalType.hpp"
#include "query_optimizer/physical/Selection.hpp"
#include "utility/HashError.hpp"

#include "glog/logging.h"

namespace quickstep {
namespace optimizer {

namespace E = ::quickstep::optimizer::expressions;
namespace P = ::quickstep::optimizer::physical;

ExtractCommonSubexpression::ExtractCommonSubexpression(
    OptimizerContext *optimizer_context)
    : optimizer_context_(optimizer_context) {
  // Register the "homogeneous" expression kinds. Both the counting pass and
  // the transformation pass recurse through these kinds into their children,
  // so a child subexpression can be shared across the whole expression list.
  // For any other kind, children are only optimized in isolation.
  homogeneous_expression_types_.insert({
      E::ExpressionType::kAlias,
      E::ExpressionType::kAttributeReference,
      E::ExpressionType::kBinaryExpression,
      E::ExpressionType::kCast,
      E::ExpressionType::kCommonSubexpression,
      E::ExpressionType::kScalarLiteral,
      E::ExpressionType::kUnaryExpression});
}

P::PhysicalPtr ExtractCommonSubexpression::applyToNode(
    const P::PhysicalPtr &input) {
  // Rewrites the expressions carried by an Aggregate, Selection, HashJoin or
  // NestedLoopsJoin node so that repeated subexpressions are labelled with
  // CommonSubexpression nodes. Every other property of the node is copied
  // from the original. Nodes of other types, and nodes whose expressions come
  // back unchanged, are returned as-is.
  //
  // NOTE(review): only the Selection rebuild forwards the node's output
  // partition scheme header; the Aggregate/HashJoin/NestedLoopsJoin rebuilds
  // below do not — confirm whether their Create() factories should get one.
  switch (input->getPhysicalType()) {
    case P::PhysicalType::kAggregate: {
      const P::AggregatePtr aggregate =
          std::static_pointer_cast<const P::Aggregate>(input);

      // Gather the grouping expressions followed by every aggregate function's
      // argument expressions into one flat list, so that subexpressions shared
      // between them are detected together. The reconstruction below relies
      // on exactly this ordering.
      std::vector<E::ExpressionPtr> expressions;
      for (const auto &expr : aggregate->grouping_expressions()) {
        expressions.emplace_back(expr);
      }
      for (const auto &expr : aggregate->aggregate_expressions()) {
        const E::AggregateFunctionPtr func =
            std::static_pointer_cast<const E::AggregateFunction>(expr->expression());
        for (const auto &arg : func->getArguments()) {
          expressions.emplace_back(arg);
        }
      }

      // Transform the expressions so that common subexpressions are labelled.
      const std::vector<E::ExpressionPtr> new_expressions =
          transformExpressions(expressions);

      if (new_expressions != expressions) {
        std::vector<E::AliasPtr> new_aggregate_expressions;
        std::vector<E::NamedExpressionPtr> new_grouping_expressions;

        // Reconstruct grouping expressions from the head of the flat list.
        const std::size_t num_grouping_expressions =
            aggregate->grouping_expressions().size();
        new_grouping_expressions.reserve(num_grouping_expressions);
        for (std::size_t i = 0; i < num_grouping_expressions; ++i) {
          DCHECK(E::SomeNamedExpression::Matches(new_expressions[i]));
          new_grouping_expressions.emplace_back(
              std::static_pointer_cast<const E::NamedExpression>(new_expressions[i]));
        }

        // Reconstruct aggregate expressions, consuming each function's
        // arguments from the rest of the flat list in the original order.
        new_aggregate_expressions.reserve(aggregate->aggregate_expressions().size());
        auto it = new_expressions.begin() + num_grouping_expressions;
        for (const auto &expr : aggregate->aggregate_expressions()) {
          const E::AggregateFunctionPtr func =
              std::static_pointer_cast<const E::AggregateFunction>(
                  expr->expression());

          const std::size_t num_arguments = func->getArguments().size();
          std::vector<E::ScalarPtr> new_arguments;
          new_arguments.reserve(num_arguments);
          for (std::size_t i = 0; i < num_arguments; ++i, ++it) {
            DCHECK(E::SomeScalar::Matches(*it));
            new_arguments.emplace_back(std::static_pointer_cast<const E::Scalar>(*it));
          }

          if (new_arguments == func->getArguments()) {
            // Arguments untouched: keep the original aggregate expression.
            new_aggregate_expressions.emplace_back(expr);
          } else {
            const E::AggregateFunctionPtr new_func =
                E::AggregateFunction::Create(func->getAggregate(),
                                             new_arguments,
                                             func->is_vector_aggregate(),
                                             func->is_distinct());
            // Keep the alias id/names so downstream references stay valid.
            new_aggregate_expressions.emplace_back(
                E::Alias::Create(expr->id(),
                                 new_func,
                                 expr->attribute_name(),
                                 expr->attribute_alias(),
                                 expr->relation_name()));
          }
        }
        // Every transformed expression must have been consumed exactly once.
        DCHECK(it == new_expressions.end());

        return P::Aggregate::Create(aggregate->input(),
                                    new_grouping_expressions,
                                    new_aggregate_expressions,
                                    aggregate->filter_predicate());
      }
      break;
    }
    case P::PhysicalType::kSelection: {
      const P::SelectionPtr selection =
          std::static_pointer_cast<const P::Selection>(input);

      // Transform Selection's project expressions.
      const std::vector<E::NamedExpressionPtr> new_expressions =
          DownCast<E::NamedExpression>(
              transformExpressions(UpCast(selection->project_expressions())));

      if (new_expressions != selection->project_expressions()) {
        // Only the project expressions change; the node's own output
        // partitioning is preserved (not the input's, which describes the
        // input's attributes rather than this node's output).
        return P::Selection::Create(selection->input(),
                                    new_expressions,
                                    selection->filter_predicate(),
                                    selection->cloneOutputPartitionSchemeHeader());
      }
      break;
    }
    case P::PhysicalType::kHashJoin: {
      const P::HashJoinPtr hash_join =
          std::static_pointer_cast<const P::HashJoin>(input);

      // Transform HashJoin's project expressions.
      const std::vector<E::NamedExpressionPtr> new_expressions =
          DownCast<E::NamedExpression>(
              transformExpressions(UpCast(hash_join->project_expressions())));

      if (new_expressions != hash_join->project_expressions()) {
        return P::HashJoin::Create(hash_join->left(),
                                   hash_join->right(),
                                   hash_join->left_join_attributes(),
                                   hash_join->right_join_attributes(),
                                   hash_join->residual_predicate(),
                                   new_expressions,
                                   hash_join->join_type());
      }
      break;
    }
    case P::PhysicalType::kNestedLoopsJoin: {
      const P::NestedLoopsJoinPtr nested_loops_join =
          std::static_pointer_cast<const P::NestedLoopsJoin>(input);

      // Transform NestedLoopsJoin's project expressions.
      const std::vector<E::NamedExpressionPtr> new_expressions =
          DownCast<E::NamedExpression>(
              transformExpressions(UpCast(nested_loops_join->project_expressions())));

      if (new_expressions != nested_loops_join->project_expressions()) {
        return P::NestedLoopsJoin::Create(nested_loops_join->left(),
                                          nested_loops_join->right(),
                                          nested_loops_join->join_predicate(),
                                          new_expressions);
      }
      break;
    }
    default:
      break;
  }

  return input;
}

std::vector<E::ExpressionPtr> ExtractCommonSubexpression::transformExpressions(
    const std::vector<E::ExpressionPtr> &expressions) {
  // Pass 1: count how many times each hashable subexpression occurs across the
  // whole list, and remember which subexpressions could be hashed at all.
  ScalarCounter occurrence_counts;
  ScalarHashable hashable_scalars;
  for (const E::ExpressionPtr &expr : expressions) {
    visitAndCount(expr, &occurrence_counts, &hashable_scalars);
  }

  // Pass 2: wrap common subexpressions in CommonSubexpression nodes.
  //
  // Any subexpression occurring more than once is a common subexpression, but
  // it is wasteful to wrap every one of them. For example, in
  // --
  //   SELECT (x+1)*(x+2), (x+1)*(x+2)*3 FROM s;
  // --
  // a single *dominant* CommonSubexpression (x+1)*(x+2) suffices; its children
  // (x+1) and (x+2) need no nodes of their own.
  //
  // Formally, a subtree S *dominates* a descendant subtree (or leaf) T if and
  // only if count[S] >= count[T]. A CommonSubexpression node is created for
  // each subexpression that is not dominated by any of its ancestors. The
  // substitution map is shared across the list so later occurrences reuse the
  // node created for the first one.
  ScalarMap substitution_map;
  std::vector<E::ExpressionPtr> transformed;
  transformed.reserve(expressions.size());
  for (const E::ExpressionPtr &expr : expressions) {
    transformed.emplace_back(visitAndTransform(expr,
                                               1,
                                               occurrence_counts,
                                               hashable_scalars,
                                               &substitution_map));
  }
  return transformed;
}

E::ExpressionPtr ExtractCommonSubexpression::transformExpression(
    const E::ExpressionPtr &expression) {
  // Single-expression convenience wrapper: runs the list transformation on a
  // one-element list, so common subexpressions are only detected within
  // `expression` itself.
  const std::vector<E::ExpressionPtr> single_expression_list{expression};
  const std::vector<E::ExpressionPtr> transformed =
      transformExpressions(single_expression_list);
  return transformed.front();
}

bool ExtractCommonSubexpression::visitAndCount(
    const E::ExpressionPtr &expression,
    ScalarCounter *counter,
    ScalarHashable *hashable) const {
  // Post-order walk that increments the occurrence count of every hashable
  // scalar subexpression and records it in `hashable`. Returns true iff
  // `expression` itself was counted.
  //
  // Children are visited only through homogeneous expression kinds. Every such
  // child is always visited (no short-circuit), but one unhashable child makes
  // the parent unhashable, which also avoids needless hash() calls on it.
  bool all_children_hashable = true;
  if (homogeneous_expression_types_.count(expression->getExpressionType()) != 0) {
    for (const E::ExpressionPtr &child : expression->children()) {
      const bool child_hashable = visitAndCount(child, counter, hashable);
      all_children_hashable = all_children_hashable && child_hashable;
    }
  }

  if (!all_children_hashable) {
    return false;
  }

  E::ScalarPtr scalar;
  if (!E::SomeScalar::MatchesWithConditionalCast(expression, &scalar)) {
    return false;
  }

  // Not every scalar supports hash(); when it throws HashNotSupported the
  // expression (and therefore all of its ancestors) is simply skipped.
  try {
    ++(*counter)[scalar];
  } catch (const HashNotSupported &) {
    return false;
  }
  hashable->emplace(scalar);
  return true;
}

E::ExpressionPtr ExtractCommonSubexpression::visitAndTransform(
    const E::ExpressionPtr &expression,
    const std::size_t max_reference_count,
    const ScalarCounter &counter,
    const ScalarHashable &hashable,
    ScalarMap *substitution_map) {
  // Rewrites `expression` bottom-up, wrapping each dominant common
  // subexpression in a CommonSubexpression node and replacing further
  // occurrences of an already-wrapped subexpression with that same node.
  //
  // @param expression The expression to transform.
  // @param max_reference_count The occurrence count of the nearest hashable
  //        ancestor in this traversal (1 at the top level). A hashable node is
  //        wrapped only when its own count is strictly greater.
  // @param counter Occurrence counts gathered by visitAndCount().
  // @param hashable The scalars that visitAndCount() managed to count.
  // @param substitution_map Already-wrapped scalars mapped to their
  //        CommonSubexpression node; updated in place.
  // @return The transformed expression, or `expression` itself if unchanged.

  // TODO(jianqiao): Currently it is hardly beneficial to make AttributeReference
  // a common subexpression due to the inefficiency of ScalarAttribute's
  // size-not-known-at-compile-time std::memcpy() calls, compared to copy-elision
  // at selection level. Even in the case of compressed column store, it requires
  // an attribute to occur at least 4 times for the common subexpression version
  // to outperform the direct decoding version. We may look into ScalarAttribute
  // and modify the heuristic here later.
  if (expression->getExpressionType() == E::ExpressionType::kScalarLiteral ||
      expression->getExpressionType() == E::ExpressionType::kAttributeReference) {
    return expression;
  }

  E::ScalarPtr scalar;
  const bool is_hashable =
      E::SomeScalar::MatchesWithConditionalCast(expression, &scalar)
          && hashable.find(scalar) != hashable.end();

  std::size_t new_max_reference_count;
  if (is_hashable) {
    // CommonSubexpression node already generated somewhere. Just refer to that
    // and return.
    const auto substitution_map_it = substitution_map->find(scalar);
    if (substitution_map_it != substitution_map->end()) {
      return substitution_map_it->second;
    }

    // Otherwise, update the dominance count. Counts never decrease going down
    // the tree, since every occurrence of a parent contains its children.
    const auto counter_it = counter.find(scalar);
    DCHECK(counter_it != counter.end());
    DCHECK_LE(max_reference_count, counter_it->second);
    new_max_reference_count = counter_it->second;
  } else {
    new_max_reference_count = max_reference_count;
  }

  // Process children.
  std::vector<E::ExpressionPtr> new_children;
  const auto homogeneous_expression_types_it =
      homogeneous_expression_types_.find(expression->getExpressionType());
  if (homogeneous_expression_types_it == homogeneous_expression_types_.end()) {
    // If child subexpressions cannot be hoisted through the current expression,
    // treat child expressions as isolated sub-optimizations.
    for (const auto &child : expression->children()) {
      new_children.emplace_back(transformExpression(child));
    }
  } else {
    for (const auto &child : expression->children()) {
      new_children.emplace_back(
          visitAndTransform(child,
                            new_max_reference_count,
                            counter,
                            hashable,
                            substitution_map));
    }
  }

  // `expression` is not necessarily a Scalar (see the conditional cast above),
  // so the rebuilt node is kept as a plain ExpressionPtr rather than being
  // downcast to E::Scalar, which would be undefined for non-scalar nodes.
  E::ExpressionPtr output;
  if (new_children == expression->children()) {
    output = expression;
  } else {
    output = expression->copyWithNewChildren(new_children);
  }

  // Wrap it with a new CommonSubexpression node if the current expression is a
  // dominant subexpression.
  if (is_hashable && new_max_reference_count > max_reference_count) {
    DCHECK(E::SomeScalar::Matches(output));
    const E::CommonSubexpressionPtr common_subexpression =
        E::CommonSubexpression::Create(
            optimizer_context_->nextExprId(),
            std::static_pointer_cast<const E::Scalar>(output));
    substitution_map->emplace(scalar, common_subexpression);
    output = common_subexpression;
  }

  return output;
}

}  // namespace optimizer
}  // namespace quickstep
