[NO ISSUE][COMP] Extension to InjectTypeCastForFunctionArgumentsRule

- user model changes: no
- storage format changes: no
- interface changes: no

Details:
Extend InjectTypeCastForFunctionArgumentsRule for functions
that need arguments casting.

Change-Id: I68c264e7885e4f7d51a90fc615891a832a69e785
Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/4163
Tested-by: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Integration-Tests: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Reviewed-by: Ali Alsuliman <ali.al.solaiman@gmail.com>
diff --git a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/InjectTypeCastForFunctionArgumentsRule.java b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/InjectTypeCastForFunctionArgumentsRule.java
index 9531e94..6c8ad3a 100644
--- a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/InjectTypeCastForFunctionArgumentsRule.java
+++ b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/InjectTypeCastForFunctionArgumentsRule.java
@@ -24,15 +24,12 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.function.IntBinaryOperator;
-import java.util.function.IntPredicate;
 
 import org.apache.asterix.dataflow.data.common.TypeResolverUtil;
 import org.apache.asterix.lang.common.util.FunctionUtil;
 import org.apache.asterix.om.functions.BuiltinFunctions;
 import org.apache.asterix.om.typecomputer.base.TypeCastUtils;
 import org.apache.asterix.om.typecomputer.impl.TypeComputeUtils;
-import org.apache.asterix.om.types.ATypeTag;
 import org.apache.asterix.om.types.IAType;
 import org.apache.commons.lang3.mutable.Mutable;
 import org.apache.commons.lang3.mutable.MutableObject;
@@ -56,9 +53,10 @@
  */
 public class InjectTypeCastForFunctionArgumentsRule implements IAlgebraicRewriteRule {
 
-    private static final Map<FunctionIdentifier, IntPredicate> FUN_TO_ARG_CHECKER = new HashMap<>();
+    private static final Map<FunctionIdentifier, BiIntPredicate> FUN_TO_ARG_CHECKER = new HashMap<>();
 
     static {
+        addFunctionAndArgChecker(BuiltinFunctions.SWITCH_CASE, InjectTypeCastForFunctionArgumentsRule::switchResultArg);
         addFunctionAndArgChecker(BuiltinFunctions.IF_MISSING, null);
         addFunctionAndArgChecker(BuiltinFunctions.IF_NULL, null);
         addFunctionAndArgChecker(BuiltinFunctions.IF_MISSING_OR_NULL, null);
@@ -66,7 +64,7 @@
     }
 
     // allows the rule to check other functions in addition to the ones specified here
-    public static void addFunctionAndArgChecker(FunctionIdentifier function, IntPredicate argChecker) {
+    public static void addFunctionAndArgChecker(FunctionIdentifier function, BiIntPredicate argChecker) {
         FUN_TO_ARG_CHECKER.put(function, argChecker);
     }
 
@@ -89,7 +87,7 @@
 
     // Injects type casts to cast return expressions' return types to a generalized type that conforms to every
     // return type.
-    private boolean injectTypeCast(ILogicalOperator op, Mutable<ILogicalExpression> exprRef,
+    private static boolean injectTypeCast(ILogicalOperator op, Mutable<ILogicalExpression> exprRef,
             IOptimizationContext context) throws AlgebricksException {
         ILogicalExpression expr = exprRef.getValue();
         if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
@@ -105,36 +103,29 @@
             }
         }
         FunctionIdentifier funcId = func.getFunctionIdentifier();
-        if (funcId.equals(BuiltinFunctions.SWITCH_CASE)) {
-            rewritten |= rewriteFunction(op, func, null, context, 2,
-                    InjectTypeCastForFunctionArgumentsRule::switchIncrement);
-        } else if (FUN_TO_ARG_CHECKER.containsKey(funcId)) {
-            rewritten |= rewriteFunction(op, func, FUN_TO_ARG_CHECKER.get(funcId), context, 0,
-                    InjectTypeCastForFunctionArgumentsRule::increment);
+        if (FUN_TO_ARG_CHECKER.containsKey(funcId)) {
+            rewritten |= rewriteFunction(op, func, FUN_TO_ARG_CHECKER.get(funcId), context);
         }
         return rewritten;
     }
 
     // Injects casts that cast types for all function parameters
-    private boolean rewriteFunction(ILogicalOperator op, AbstractFunctionCallExpression func, IntPredicate argChecker,
-            IOptimizationContext context, int argStartIdx, IntBinaryOperator increment) throws AlgebricksException {
+    private static boolean rewriteFunction(ILogicalOperator op, AbstractFunctionCallExpression func,
+            BiIntPredicate argChecker, IOptimizationContext context) throws AlgebricksException {
         IVariableTypeEnvironment env = op.computeInputTypeEnvironment(context);
         IAType producedType = (IAType) env.getType(func);
-        if (!argumentsNeedCasting(producedType)) {
-            return false;
-        }
         List<Mutable<ILogicalExpression>> argRefs = func.getArguments();
         int argSize = argRefs.size();
         boolean rewritten = false;
-        for (int argIndex = argStartIdx; argIndex < argSize; argIndex += increment.applyAsInt(argIndex, argSize)) {
-            if (argChecker == null || argChecker.test(argIndex)) {
+        for (int argIndex = 0; argIndex < argSize; argIndex++) {
+            if (argChecker == null || argChecker.test(argIndex, argSize)) {
                 rewritten |= rewriteFunctionArgument(argRefs.get(argIndex), producedType, env);
             }
         }
         return rewritten;
     }
 
-    private boolean rewriteFunctionArgument(Mutable<ILogicalExpression> argRef, IAType funcOutputType,
+    private static boolean rewriteFunctionArgument(Mutable<ILogicalExpression> argRef, IAType funcOutputType,
             IVariableTypeEnvironment env) throws AlgebricksException {
         ILogicalExpression argExpr = argRef.getValue();
         IAType type = (IAType) env.getType(argExpr);
@@ -152,17 +143,13 @@
         return false;
     }
 
-    private static boolean argumentsNeedCasting(IAType functionProducedType) {
-        ATypeTag functionProducedTag = TypeComputeUtils.getActualType(functionProducedType).getTypeTag();
-        return functionProducedTag == ATypeTag.ANY || functionProducedTag.isDerivedType();
+    public static boolean switchResultArg(int argIdx, int numArguments) {
+        // e.g. switch(cond, exp1, res1, exp2, res2, def_res)
+        return argIdx > 1 && (argIdx % 2 == 0 || argIdx == numArguments - 1);
     }
 
-    private static int switchIncrement(int currentArgIndex, int numArguments) {
-        return currentArgIndex + 2 == numArguments ? 1 : 2;
+    @FunctionalInterface
+    public interface BiIntPredicate {
+        boolean test(int argIndex, int numArguments);
     }
-
-    @SuppressWarnings("squid:S1172") // unused parameter
-    private static int increment(int currentArgIndex, int numArguments) {
-        return 1;
-    }
-}
+}
\ No newline at end of file
diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/SwitchCaseComputer.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/SwitchCaseComputer.java
index 59bfddf..db43bc7 100644
--- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/SwitchCaseComputer.java
+++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/SwitchCaseComputer.java
@@ -18,6 +18,8 @@
  */
 package org.apache.asterix.om.typecomputer.impl;
 
+import static org.apache.asterix.om.types.BuiltinType.ANULL;
+
 import java.util.ArrayList;
 import java.util.List;
 
@@ -45,12 +47,11 @@
         AbstractFunctionCallExpression fce = (AbstractFunctionCallExpression) expression;
         String funcName = fce.getFunctionIdentifier().getName();
 
-        int argNumber = fce.getArguments().size();
-        if (argNumber < 3) {
-            throw new CompilationException(ErrorCode.COMPILATION_INVALID_PARAMETER_NUMBER, fce.getSourceLocation(),
-                    funcName, argNumber);
-        }
         int argSize = fce.getArguments().size();
+        if (argSize < 3) {
+            throw new CompilationException(ErrorCode.COMPILATION_INVALID_PARAMETER_NUMBER, fce.getSourceLocation(),
+                    funcName, argSize);
+        }
         List<IAType> types = new ArrayList<>();
         // Collects different branches' return types.
         // The last return expression is from the ELSE branch and it is optional.
@@ -58,6 +59,15 @@
             IAType type = (IAType) env.getType(fce.getArguments().get(argIndex).getValue());
             types.add(type);
         }
+        // TODO(ali): investigate if needed for CASE. assumption seems to be that CASE is always rewritten with default
+        if (addDefaultNull(argSize)) {
+            types.add(ANULL);
+        }
         return TypeResolverUtil.resolve(types);
     }
+
+    private boolean addDefaultNull(int argSize) {
+        // null is the default value for odd arg size (e.g. fun(cond_exp, exp1, res1))
+        return argSize % 2 != 0;
+    }
 }