// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.impala.analysis;

import org.apache.impala.analysis.AnalysisContext.AnalysisResult;
import org.apache.impala.common.AnalysisException;
import org.apache.impala.common.ImpalaException;
import org.apache.impala.common.RuntimeEnv;
import org.apache.impala.rewrite.EqualityDisjunctsToInRule;
import org.apache.impala.rewrite.ExprRewriteRule;
import org.apache.impala.rewrite.ExprRewriter;
import org.apache.impala.thrift.TQueryOptions;
import org.junit.Assert;
import org.junit.Test;

import com.google.common.base.Preconditions;

import static org.apache.impala.analysis.ToSqlOptions.DEFAULT;
import static org.apache.impala.analysis.ToSqlOptions.REWRITTEN;
import static org.apache.impala.analysis.ToSqlOptions.SHOW_IMPLICIT_CASTS;

/**
 * Tests that the ExprRewriter framework covers all clauses as well as nested statements.
 * It also tests some specific rewrite rules.
 */
public class ExprRewriterTest extends AnalyzerTest {

  /**
   * Replaces any Expr that does not contain a Subquery with a TRUE BoolLiteral.
   */
  static class ExprToBoolRule implements ExprRewriteRule {
    public static ExprToBoolRule INSTANCE = new ExprToBoolRule();

    @Override
    public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException {
      if (expr.contains(Subquery.class)) return expr;
      if (Expr.IS_TRUE_LITERAL.apply(expr)) return expr;
      return new BoolLiteral(true);
    }

    private ExprToBoolRule() {}
  }

  /**
   * Replaces a TRUE BoolLiteral with a FALSE BoolLiteral.
   */
  static class TrueToFalseRule implements ExprRewriteRule {
    public static TrueToFalseRule INSTANCE = new TrueToFalseRule();

    @Override
    public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException {
      if (Expr.IS_TRUE_LITERAL.apply(expr)) return new BoolLiteral(false);
      return expr;
    }
    private TrueToFalseRule() {}
  }

  private final ExprRewriter exprToTrue_ = new ExprRewriter(ExprToBoolRule.INSTANCE);
  private final ExprRewriter trueToFalse_ = new ExprRewriter(TrueToFalseRule.INSTANCE);

  /**
   * Analyzes 'stmt' and rewrites Exprs with 'exprToTrue_' and validates the following:
   * 1. The actual number of changed Exprs should be equal to 'expectedNumChanges'.
   * 2. Checks that the Exprs were actually rewritten by doing another round of
   *    rewriting using 'trueToFalse_'. The expected number of changes for this
   *    second rewriting is 'expectedNumExprTrees'.
   * Does not use an AnalysisContext to avoid rewriting subqueries which might alter the
   * number of expressions and complicate validation.
   */
  public void RewritesOk(String stmt, int expectedNumChanges,
      int expectedNumExprTrees) throws ImpalaException {
    // Analyze without rewrites since that's what we want to test here.
    StatementBase parsedStmt = (StatementBase) ParsesOk(stmt);
    AnalyzesOkNoRewrite(parsedStmt);
    exprToTrue_.reset();
    parsedStmt.rewriteExprs(exprToTrue_);
    Assert.assertEquals(expectedNumChanges, exprToTrue_.getNumChanges());

    // Verify that the Exprs were actually replaced.
    trueToFalse_.reset();
    parsedStmt.rewriteExprs(trueToFalse_);
    Assert.assertEquals(expectedNumExprTrees, trueToFalse_.getNumChanges());

    // Make sure the stmt can be successfully re-analyzed.
    parsedStmt.reset();
    AnalyzesOkNoRewrite(parsedStmt);
  }

  /**
   * Asserts that no rewrites are performed on the given stmt.
   */
  public void CheckNoRewrite(String stmt) throws ImpalaException {
    exprToTrue_.reset();
    AnalysisContext analysisCtx = createAnalysisCtx();
    AnalysisResult result = parseAndAnalyze(stmt, analysisCtx);
    Preconditions.checkNotNull(result.getStmt());
    Assert.assertEquals(0, exprToTrue_.getNumChanges());
  }

  // Select statement with all clauses that has 11 rewritable Expr trees.
  // We expect a total of 23 exprs to be changed.
  private final String stmt_ =
      "select a.int_col a, 10 b, 20.2 c, count(b.int_col) cnt from " +
      "functional.alltypes a join functional.alltypes b on (a.id = b.id)" +
      "where b.float_col > 1 and b.double_col > 2 " +
      "group by 1, a.string_col " +
      "having count(b.int_col) < 3 " +
      "order by a.int_col, 4 limit 10";

  @Test
  public void TestQueryStmts() throws ImpalaException {
    RewritesOk(stmt_, 23, 11);
    // Test rewriting in inline views. The view stmt is the same as the query above
    // but with an order by + limit. Expanded star exprs are not rewritten.
    RewritesOk("select * from (" + stmt_ + ") v", 23, 11);
    // Test union, 11 + 11 + 1 rewritable Expr trees.
    RewritesOk(String.format("%s union all (%s) order by cnt", stmt_, stmt_), 47, 23);
    // Test union inside an inline view.
    RewritesOk(String.format("select * from (%s union all (%s) order by cnt limit 10) v",
        stmt_, stmt_), 47, 23);
    // Constant select.
    RewritesOk("select 1, 2, 3, 4", 4, 4);
    // Values stmt.
    RewritesOk("values(1, '2', 3, 4.1), (1, '2', 3, 4.1)", 8, 8);
    // Test WHERE-clause subqueries.
    RewritesOk("select id, int_col from functional.alltypes a " +
        "where exists (select 1 from functional.alltypes " +
        "where string_col = 'test' having count(*) < 10)", 9, 5);
    RewritesOk("select id, int_col from functional.alltypes a " +
        "where a.id in (select count(*) from functional.alltypes " +
        "where string_col = 'test' having count(*) < 10)", 10, 6);
  }

  @Test
  public void TestDdlStmts() throws ImpalaException {
    RewritesOk("create table ctas_test as " + stmt_, 23, 11);
    // Create/alter view stmts are not rewritten to preserve the original SQL.
    CheckNoRewrite("create view view_test as " + stmt_);
    CheckNoRewrite("alter view functional.alltypes_view as " + stmt_);
  }

  @Test
  public void TestDmlStmts() throws ImpalaException {
    // Insert.
    RewritesOk("insert into functional.alltypes (id, int_col, float_col, bigint_col) " +
      "partition(year=2009,month=10) " + stmt_, 23, 11);

    if (RuntimeEnv.INSTANCE.isKuduSupported()) {
      // Update.
      RewritesOk("update t2 set name = 'test' from " +
          "functional.alltypes t1 join functional_kudu.dimtbl t2 on (t1.id = t2.id) " +
          "where t2.id < 10", 10, 5);
      RewritesOk("update functional_kudu.dimtbl set name = 'test', zip = 4711 " +
          "where exists (" + stmt_ + ")", 28, 16);
      // Delete.
      RewritesOk("delete a from " +
          "functional_kudu.testtbl a join functional.testtbl b on a.zip = b.zip", 4, 2);
      RewritesOk("delete functional_kudu.testtbl where exists (" + stmt_ + ")", 24, 12);
    }
  }

  /**
   * construct an in-list: string_col in [offset ... offset + length)
   */
  private void CreateInList(int offset, int length, StringBuilder stmtSb) {
    stmtSb.append("string_col in(");
    for (int j = 0; j < length - 1; ++j) {
      stmtSb.append("'c").append(offset + j).append("',");
    }
    stmtSb.append("'c").append(offset + length - 1).append("')");
  }

  private void CheckNumChangesByEqualityDisjunctsToInRule(
      String stmt, int expectedNumChanges) throws ImpalaException {
    StatementBase parsedStmt = (StatementBase) ParsesOk(stmt);
    AnalyzesOkNoRewrite(parsedStmt);
    ExprRewriter rewriter = new ExprRewriter(EqualityDisjunctsToInRule.INSTANCE);
    parsedStmt.rewriteExprs(rewriter);
    Assert.assertEquals(expectedNumChanges, rewriter.getNumChanges());
  }

  @Test
  public void TestEqualityDisjunctsToInRuleSizeLimit() throws ImpalaException {
    String stmtPrefix = "select count(*) from functional.alltypes where ( ";
    // Test that EqualityDisjunctsToInRule doesn't create an expr with a number of
    // children exceeding the limit.
    {
      // Create a disjunct with 2 in-lists of length Expr.EXPR_CHILDREN_LIMIT - 1.
      StringBuilder stmtSb = new StringBuilder(stmtPrefix);
      for (int i = 0; i < 2; ++i) {
        int offset = (Expr.EXPR_CHILDREN_LIMIT - 1) * i;
        CreateInList(offset, Expr.EXPR_CHILDREN_LIMIT - 1, stmtSb);
        if (i != 1) stmtSb.append(" or ");
      }
      stmtSb.append(")");
      CheckNumChangesByEqualityDisjunctsToInRule(stmtSb.toString(), 0);
    }
    {
      // Create a disjunct with an in-list of length Expr.EXPR_CHILDREN_LIMIT - 1 and a
      // EQ predicate.
      StringBuilder stmtSb = new StringBuilder(stmtPrefix);
      CreateInList(0, Expr.EXPR_CHILDREN_LIMIT - 1, stmtSb);
      stmtSb.append("or string_col='").append(Expr.EXPR_CHILDREN_LIMIT - 1).append("')");
      CheckNumChangesByEqualityDisjunctsToInRule(stmtSb.toString(), 0);
    }
    {
      // Create a disjunct with an in-list of length Expr.EXPR_CHILDREN_LIMIT - 2 and 2
      // EQ predicates.
      StringBuilder stmtSb = new StringBuilder(stmtPrefix);
      CreateInList(0, Expr.EXPR_CHILDREN_LIMIT - 2, stmtSb);
      stmtSb.append("or string_col='").append(Expr.EXPR_CHILDREN_LIMIT - 2)
          .append("' or string_col='").append(Expr.EXPR_CHILDREN_LIMIT - 1).append("')");
      CheckNumChangesByEqualityDisjunctsToInRule(stmtSb.toString(), 1);
    }
  }

  @Test
  public void TestToSql() {
    TQueryOptions options = new TQueryOptions();
    options.setEnable_expr_rewrites(true);
    AnalysisContext ctx = createAnalysisCtx(options);

    //----------------------
    // Test query rewrites.
    //----------------------
    assertToSql(ctx, "select 1 + 1", "SELECT 1 + 1", "SELECT 2");

    assertToSql(ctx,
        "select (case when true then 1 else id end) from functional.alltypes " +
        "union " +
        "select 1 + 1",
        "SELECT (CASE WHEN TRUE THEN 1 ELSE id END) FROM functional.alltypes " +
        "UNION " +
        "SELECT 1 + 1",
        "SELECT 1 FROM functional.alltypes UNION SELECT 2");

    assertToSql(ctx,
        "values(1, '2', 3, 4.1), (1, '2', 3, 4.1)",
        "VALUES((1, '2', 3, 4.1), (1, '2', 3, 4.1))",
        "SELECT 1, '2', 3, 4.1 UNION ALL SELECT 1, '2', 3, 4.1");

    assertToSql(ctx,
        "select case when 1 = 1 then 1 else 2.0 end from functional.alltypes",
        "SELECT CASE WHEN 1 = 1 THEN 1 ELSE 2.0 END FROM functional.alltypes",
        "SELECT 1.0 FROM functional.alltypes");

    assertToSql(ctx,
        "select case when false then 1.0 else 2 end from functional.alltypes",
        "SELECT CASE WHEN FALSE THEN 1.0 ELSE 2 END FROM functional.alltypes",
        "SELECT 2.0 FROM functional.alltypes");

    assertToSql(ctx,
        "select * from functional.alltypes where case " +
        "when true = true then year < 2019 " +
        "when false then year > 2010 end",
        "SELECT * FROM functional.alltypes WHERE CASE " +
        "WHEN TRUE = TRUE THEN `year` < 2019 " +
        "WHEN FALSE THEN `year` > 2010 END",
        "SELECT * FROM functional.alltypes WHERE `year` < 2019");

    //-------------------------
    // Test subquery rewrites.
    //-------------------------
    assertToSql(ctx, "select * from (" +
        "select * from functional.alltypes where id = (select 1 + 1)) a",
        "SELECT * FROM (SELECT * FROM functional.alltypes WHERE id = (SELECT 1 + 1)) a",
        "SELECT * FROM (SELECT * FROM functional.alltypes LEFT SEMI JOIN " +
        "(SELECT 2) `$a$1` (`$c$1`) ON id = `$a$1`.`$c$1`) a");

    assertToSql(ctx,
        "select * from (select * from functional.alltypes where id = (select 1 + 1)) a " +
        "union " +
        "select * from (select * from functional.alltypes where id = (select 1 + 1)) b",
        "SELECT * FROM (SELECT * FROM functional.alltypes WHERE id = (SELECT 1 + 1)) a " +
        "UNION " +
        "SELECT * FROM (SELECT * FROM functional.alltypes WHERE id = (SELECT 1 + 1)) b",
        "SELECT * FROM (SELECT * FROM functional.alltypes LEFT SEMI JOIN (SELECT 2) " +
        "`$a$1` (`$c$1`) ON id = `$a$1`.`$c$1`) a " +
        "UNION " +
        "SELECT * FROM (SELECT * FROM functional.alltypes LEFT SEMI JOIN (SELECT 2) " +
        "`$a$1` (`$c$1`) ON id = `$a$1`.`$c$1`) b");

    assertToSql(ctx, "select * from " +
        "(select (case when true then 1 else id end) from functional.alltypes " +
        "union select 1 + 1) v",
        "SELECT * FROM (SELECT (CASE WHEN TRUE THEN 1 ELSE id END) " +
        "FROM functional.alltypes UNION SELECT 1 + 1) v",
        "SELECT * FROM (SELECT 1 FROM functional.alltypes " +
        "UNION SELECT 2) v");

    //---------------------
    // Test CTAS rewrites.
    //---------------------
    assertToSql(ctx,
        "create table ctas_test as select 1 + 1",
        "CREATE TABLE default.ctas_test\n" +
        "STORED AS TEXTFILE\n" +
        " AS SELECT 1 + 1",
        "CREATE TABLE default.ctas_test\n" +
        "STORED AS TEXTFILE\n" +
        " AS SELECT 2");

    //--------------------
    // Test DML rewrites.
    //--------------------
    // Insert
    assertToSql(ctx,
        "insert into functional.alltypes(id) partition(year=2009, month=10) " +
        "select 1 + 1",
        "INSERT INTO TABLE functional.alltypes(id) " +
        "PARTITION (`year`=2009, `month`=10) SELECT 1 + 1",
        "INSERT INTO TABLE functional.alltypes(id) " +
        "PARTITION (`year`=2009, `month`=10) SELECT 2");

    if (RuntimeEnv.INSTANCE.isKuduSupported()) {
      // Update.
      assertToSql(ctx,
          "update functional_kudu.alltypes "
              + "set string_col = 'test' where id = (select 1 + 1)",
          "UPDATE functional_kudu.alltypes SET string_col = 'test' "
              + "FROM functional_kudu.alltypes WHERE id = (SELECT 1 + 1)",
          "UPDATE functional_kudu.alltypes SET string_col = 'test' "
              + "FROM functional_kudu.alltypes LEFT SEMI JOIN (SELECT 2) `$a$1` (`$c$1`) "
              + "ON id = `$a$1`.`$c$1` WHERE id = (SELECT 2)");

      // Delete
      assertToSql(ctx,
          "delete functional_kudu.alltypes "
              + "where id = (select 1 + 1)",
          "DELETE FROM functional_kudu.alltypes "
              + "WHERE id = (SELECT 1 + 1)",
          "DELETE functional_kudu.alltypes "
              + "FROM functional_kudu.alltypes LEFT SEMI JOIN (SELECT 2) `$a$1` (`$c$1`) "
              + "ON id = `$a$1`.`$c$1` WHERE id = (SELECT 2)");
    }

    // We don't do any rewrite for WITH clause.
    StatementBase stmt = (StatementBase) AnalyzesOk("with t as (select 1 + 1) " +
        "select id from functional.alltypes union select id from functional.alltypesagg",
        ctx);
    Assert.assertEquals(stmt.toSql(), stmt.toSql());
  }

  @Test
  /**
   * Test printing of implicit casts
   */
  public void TestToSqlWithImplicitCasts() {
    TQueryOptions options = new TQueryOptions();
    options.setEnable_expr_rewrites(true);
    AnalysisContext ctx = createAnalysisCtx(options);

    assertToSqlWithImplicitCasts(ctx,
        "select * from functional_kudu.alltypestiny where bigint_col < "
            + "1000 / 100",
        "SELECT * FROM functional_kudu.alltypestiny WHERE "
            + "CAST(bigint_col AS DOUBLE) < CAST(10 AS DOUBLE)");
    assertToSqlWithImplicitCasts(ctx,
        "select float_col + 1.1 from functional.alltypestiny",
        "SELECT CAST(float_col AS DECIMAL(38,9)) + CAST(1.1 AS DECIMAL(2,1)) "
            + "FROM functional.alltypestiny");
    assertToSqlWithImplicitCasts(
        ctx, "select cast(2 as bigint)", "SELECT CAST(2 AS BIGINT)");
    assertToSqlWithImplicitCasts(ctx, "select cast(2 as decimal(38,37))",
        "SELECT CAST(2.0000000000000000000000000000000000000 AS DECIMAL(38,37))");
    assertToSqlWithImplicitCasts(ctx, "select d1 - 1.1 from functional.decimal_tbl",
        "SELECT d1 - CAST(1.1 AS DECIMAL(2,1)) FROM functional.decimal_tbl");
    assertToSqlWithImplicitCasts(ctx, "select * from functional.date_tbl "
        + "where date_col = '2017-11-28'",
        "SELECT * FROM functional.date_tbl "
        + "WHERE date_col = DATE '2017-11-28'");
    assertToSqlWithImplicitCasts(ctx, "select * from functional.alltypes, "
        + "functional.date_tbl where timestamp_col = date_col",
        "SELECT * FROM functional.alltypes, functional.date_tbl "
        + "WHERE timestamp_col = CAST(date_col AS TIMESTAMP)");
    assertToSqlWithImplicitCasts(ctx, "select round(1.2345, 2) * pow(10, 10)",
        "SELECT CAST(12300000000 AS DOUBLE)");
    assertToSqlWithImplicitCasts(ctx,
        "select * from functional.alltypes where "
            + "double_col in (int_col, bigint_col)",
        "SELECT * FROM functional.alltypes WHERE double_col IN "
            + "(CAST(int_col AS DOUBLE), CAST(bigint_col AS DOUBLE))");
    assertToSqlWithImplicitCasts(ctx,
        "select * from functional.alltypes "
            + "where double_col between smallint_col and int_col",
        "SELECT * FROM functional.alltypes WHERE double_col >= "
            + "CAST(smallint_col AS DOUBLE) AND double_col <= CAST(int_col AS DOUBLE)");
    assertToSqlWithImplicitCasts(ctx,
        "select * from "
            + "(select 10 as i, 2 as j, 2013 as s) as t "
            + "where t.i < 10",
        "SELECT * FROM "
            + "(SELECT CAST(10 AS TINYINT) i, CAST(2 AS TINYINT) j, "
            + "CAST(2013 AS SMALLINT) s) t "
            + "WHERE t.i < CAST(10 AS TINYINT)");
    assertToSqlWithImplicitCasts(ctx,
        "select * from (select id, int_col, year,  sum(int_col) "
            + " over(partition by year order by id) as s from functional.alltypes) v "
            + " where year = 2009 and id = 1 and"
            + " int_col < 10 and s = 4",
        "SELECT * FROM (SELECT id, int_col, `year`, sum(int_col)"
            + " OVER (PARTITION BY `year` ORDER BY id ASC) s FROM functional.alltypes) v"
            + " WHERE `year` = CAST(2009 AS INT) AND id = CAST(1 AS INT) AND"
            + " int_col < CAST(10 AS INT) AND s = CAST(4 AS BIGINT)");
    assertToSqlWithImplicitCasts(ctx,
        "select * from functional.alltypes where "
            + "int_col = 1 or int_col = 2 "
            + "or tinyint_col > 5 AND "
            + "(float_col = 5 or double_col = 6)",
        "SELECT * FROM functional.alltypes WHERE "
            + "int_col IN (CAST(1 AS INT), CAST(2 AS INT)) "
            + "OR tinyint_col > CAST(5 AS TINYINT) AND "
            + "(float_col = CAST(5 AS FLOAT) OR double_col = CAST(6 AS DOUBLE))");

    checkNumericLiteralCasts(ctx, "tinyint_col", "1", "TINYINT");
    checkNumericLiteralCasts(ctx, "smallint_col", "1", "TINYINT");
    checkNumericLiteralCasts(ctx, "smallint_col", "1000", "SMALLINT");
    checkNumericLiteralCasts(ctx, "int_col", "1", "TINYINT");
    checkNumericLiteralCasts(ctx, "int_col", "1000", "SMALLINT");
    checkNumericLiteralCasts(ctx, "int_col", "1000000", "INT");
    checkNumericLiteralCasts(ctx, "bigint_col", "1", "TINYINT");
    checkNumericLiteralCasts(ctx, "bigint_col", "1000", "SMALLINT");
    checkNumericLiteralCasts(ctx, "bigint_col", "1000000", "INT");
    checkNumericLiteralCasts(ctx, "bigint_col", "10000000000", "BIGINT");
    checkNumericLiteralCasts(ctx, "float_col", "1", "TINYINT");
    checkNumericLiteralCasts(ctx, "float_col", "1.0", "DECIMAL(2,1)");
    checkNumericLiteralCasts(ctx, "float_col", "100000.001", "DECIMAL(9,3)");
    checkNumericLiteralCasts(ctx, "double_col", "1", "TINYINT");
    checkNumericLiteralCasts(ctx, "double_col", "1.0", "DECIMAL(2,1)");
    checkNumericLiteralCasts(ctx, "double_col", "100000.001", "DECIMAL(9,3)");
  }

  /**
   * Generate an insert query into a column and check that the toSql() with implicit casts
   * looks as expected.
   * columnName is the name of a column in functional.alltypesnopart.
   * data is the literal value to insert.
   * castColumn is the type to which the literal is expected to be cast.
   */
  private void checkNumericLiteralCasts(
      AnalysisContext ctx, String columnName, String data, String castColumn) {
    String query = "insert into table functional.alltypesnopart (" + columnName + ") "
        + "values(" + data + ")";
    String expectedToSql = "INSERT INTO TABLE "
        + "functional.alltypesnopart(" + columnName + ") "
        + "SELECT CAST(" + data + " AS " + castColumn + ")"
        + " UNION "
        + "SELECT CAST(" + data + " AS " + castColumn + ")";
    assertToSqlWithImplicitCasts(ctx, query, expectedToSql);
  }

  private void assertToSql(AnalysisContext ctx, String query, String expectedToSql,
      String expectedToRewrittenSql) {
    StatementBase stmt = (StatementBase) AnalyzesOk(query, ctx);
    Assert.assertEquals(expectedToSql, stmt.toSql(DEFAULT));
    Assert.assertEquals(expectedToSql, stmt.toSql());
    Assert.assertEquals(expectedToRewrittenSql, stmt.toSql(REWRITTEN));
  }

  private void assertToSqlWithImplicitCasts(
      AnalysisContext ctx, String query, String expectedToSqlWithImplicitCasts) {
    StatementBase stmt = (StatementBase) AnalyzesOk(query, ctx);
    String actual = stmt.toSql(SHOW_IMPLICIT_CASTS);
    Assert.assertEquals("Bad sql with implicit casts from original query:\n" + query,
        expectedToSqlWithImplicitCasts, actual);
  }
}
