fix a bugs related to SQL type inference return type nullability (#11327)

* fix a bunch of type inference nullability bugs

* fixes

* style

* fix test

* fix concat
diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java
index 711acad..a0b8e59 100644
--- a/core/src/main/java/org/apache/druid/math/expr/Function.java
+++ b/core/src/main/java/org/apache/druid/math/expr/Function.java
@@ -2721,6 +2721,9 @@
     @Override
     protected ExprEval eval(String x, int y)
     {
+      if (x == null) {
+        return ExprEval.of(null);
+      }
       return ExprEval.of(y < 1 ? NullHandling.defaultStringValue() : StringUtils.repeat(x, y));
     }
   }
diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
index 5ded90f..8ee9898 100644
--- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
@@ -596,6 +596,16 @@
     assertExpr("bitwiseConvertDoubleToLongBits(null)", null);
   }
 
+  @Test
+  public void testRepeat()
+  {
+    assertExpr("repeat('hello', 2)", "hellohello");
+    assertExpr("repeat('hello', -1)", null);
+    assertExpr("repeat(null, 10)", null);
+    assertExpr("repeat(nonexistent, 10)", null);
+  }
+
+
   private void assertExpr(final String expression, @Nullable final Object expectedResult)
   {
     final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
diff --git a/docs/querying/sql.md b/docs/querying/sql.md
index 6f837de..12d0da9 100644
--- a/docs/querying/sql.md
+++ b/docs/querying/sql.md
@@ -303,7 +303,7 @@
 In SQL compatible mode (`false`), NULLs are treated more closely to the SQL standard. The property affects both storage
 and querying, so for correct behavior, it should be set on all Druid service types to be available at both ingestion
 time and query time. There is some overhead associated with the ability to handle NULLs; see
-the [segment internals](../design/segments.md#sql-compatible-null-handling)documentation for more details.
+the [segment internals](../design/segments.md#sql-compatible-null-handling) documentation for more details.
 
 ## Aggregation functions
 
diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java
index e7c39d0..de1293a 100644
--- a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java
+++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java
@@ -99,7 +99,11 @@
     @Override
     public ExprEval eval(final ObjectBinding bindings)
     {
-      return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step));
+      ExprEval timestamp = args.get(0).eval(bindings);
+      if (timestamp.isNumericNull()) {
+        return ExprEval.of(null);
+      }
+      return ExprEval.of(chronology.add(period, timestamp.asLong(), step));
     }
 
     @Override
@@ -128,10 +132,14 @@
     @Override
     public ExprEval eval(final ObjectBinding bindings)
     {
+      ExprEval timestamp = args.get(0).eval(bindings);
+      if (timestamp.isNumericNull()) {
+        return ExprEval.of(null);
+      }
       final Period period = getPeriod(args, bindings);
       final Chronology chronology = getTimeZone(args, bindings);
       final int step = getStep(args, bindings);
-      return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step));
+      return ExprEval.of(chronology.add(period, timestamp.asLong(), step));
     }
 
     @Override
diff --git a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java
index c4710f9..05945b1 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java
@@ -20,6 +20,7 @@
 package org.apache.druid.query.expression;
 
 import com.google.common.collect.ImmutableList;
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.java.util.common.DateTimes;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.math.expr.Expr;
@@ -219,6 +220,24 @@
     );
   }
 
+  @Test
+  public void testNull()
+  {
+    Expr expr = apply(
+        ImmutableList.of(
+            ExprEval.ofLong(null).toExpr(),
+            ExprEval.of("P1M").toExpr(),
+            ExprEval.of(1L).toExpr()
+        )
+    );
+
+    if (NullHandling.replaceWithDefault()) {
+      Assert.assertEquals(2678400000L, expr.eval(ExprUtils.nilBindings()).value());
+    } else {
+      Assert.assertNull(expr.eval(ExprUtils.nilBindings()).value());
+    }
+  }
+
   private static class NotLiteralExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
   {
     NotLiteralExpr(String name)
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
index 6f060f1..36607bf 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
@@ -48,6 +48,7 @@
 import org.apache.calcite.sql.type.SqlReturnTypeInference;
 import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.type.SqlTypeTransforms;
 import org.apache.calcite.util.Static;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.ISE;
@@ -255,11 +256,12 @@
     }
 
     /**
-     * Sets the return type of the operator to "typeName", marked as non-nullable.
+     * Sets the return type of the operator to "typeName", marked as non-nullable. If this method is used it implies the
+     * operator should never, ever, return null.
      *
-     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
-     * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
-     * cannot be mixed; you must call exactly one.
+     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
+     * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
+     * calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
      */
     public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName)
     {
@@ -274,9 +276,9 @@
     /**
      * Sets the return type of the operator to "typeName", marked as nullable.
      *
-     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
-     * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
-     * cannot be mixed; you must call exactly one.
+     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
+     * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
+     * calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
      */
     public OperatorBuilder returnTypeNullable(final SqlTypeName typeName)
     {
@@ -287,12 +289,27 @@
       );
       return this;
     }
+
+    /**
+     * Sets the return type of the operator to "typeName", marked as nullable if any of its operands are nullable.
+     *
+     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
+     * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
+     * calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
+     */
+    public OperatorBuilder returnTypeCascadeNullable(final SqlTypeName typeName)
+    {
+      Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times");
+      this.returnTypeInference = ReturnTypes.cascade(ReturnTypes.explicit(typeName), SqlTypeTransforms.TO_NULLABLE);
+      return this;
+    }
+
     /**
      * Sets the return type of the operator to an array type with elements of "typeName", marked as nullable.
      *
-     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
-     * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
-     * cannot be mixed; you must call exactly one.
+     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
+     * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
+     * calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
      */
     public OperatorBuilder returnTypeNullableArray(final SqlTypeName elementTypeName)
     {
@@ -308,9 +325,9 @@
     /**
      * Provides customized return type inference logic.
      *
-     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or
-     * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods
-     * cannot be mixed; you must call exactly one.
+     * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)}
+     * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before
+     * calling {@link #build()}. These methods cannot be mixed; you must call exactly one.
      */
     public OperatorBuilder returnTypeInference(final SqlReturnTypeInference returnTypeInference)
     {
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java
index 073d935..9e67cc3 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java
@@ -43,7 +43,7 @@
           )
       )
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.INTEGER)
+      .returnTypeCascadeNullable(SqlTypeName.INTEGER)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java
index 51cad2f..ca026c5 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java
@@ -47,7 +47,7 @@
           )
       )
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.INTEGER)
+      .returnTypeNullable(SqlTypeName.INTEGER)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java
index 12edb57..dfc1501 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java
@@ -47,7 +47,7 @@
           )
       )
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.INTEGER)
+      .returnTypeCascadeNullable(SqlTypeName.INTEGER)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java
index 5d316a5..285993b 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java
@@ -47,7 +47,7 @@
           )
       )
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java
index d77c20b..648d54b 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java
@@ -37,7 +37,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder("BTRIM")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .functionCategory(SqlFunctionCategory.STRING)
       .requiredOperands(1)
       .build();
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java
index e7dbf50..7ffc47d 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java
@@ -22,29 +22,22 @@
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
-import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.type.OperandTypes;
-import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.sql.calcite.expression.DruidExpression;
 import org.apache.druid.sql.calcite.expression.OperatorConversions;
 import org.apache.druid.sql.calcite.expression.SqlOperatorConversion;
-import org.apache.druid.sql.calcite.planner.Calcites;
 import org.apache.druid.sql.calcite.planner.PlannerContext;
 
 public class ConcatOperatorConversion implements SqlOperatorConversion
 {
-  private static final SqlFunction SQL_FUNCTION = new SqlFunction(
-      "CONCAT",
-      SqlKind.OTHER_FUNCTION,
-      ReturnTypes.explicit(
-          factory -> Calcites.createSqlType(factory, SqlTypeName.VARCHAR)
-      ),
-      null,
-      OperandTypes.SAME_VARIADIC,
-      SqlFunctionCategory.STRING
-  );
+  private static final SqlFunction SQL_FUNCTION = OperatorConversions
+      .operatorBuilder("CONCAT")
+      .operandTypeChecker(OperandTypes.SAME_VARIADIC)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
+      .functionCategory(SqlFunctionCategory.STRING)
+      .build();
 
   @Override
   public SqlFunction calciteOperator()
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java
index f496e0a..574fa2f 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java
@@ -67,7 +67,7 @@
       .operatorBuilder("DATE_TRUNC")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP)
       .requiredOperands(2)
-      .returnTypeNonNull(SqlTypeName.TIMESTAMP)
+      .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java
index 2d13b02..3d98d3e 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java
@@ -37,7 +37,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder("LPAD")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .functionCategory(SqlFunctionCategory.STRING)
       .requiredOperands(2)
       .build();
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java
index 70ec0c9..233ded0 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java
@@ -37,7 +37,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder("LTRIM")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .functionCategory(SqlFunctionCategory.STRING)
       .requiredOperands(1)
       .build();
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java
index 252343c..deeffa5 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java
@@ -39,7 +39,7 @@
       .operatorBuilder("LEFT")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java
index e8b8e74..2456f05 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java
@@ -39,7 +39,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder("MILLIS_TO_TIMESTAMP")
       .operandTypes(SqlTypeFamily.EXACT_NUMERIC)
-      .returnTypeNonNull(SqlTypeName.TIMESTAMP)
+      .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java
index 9fd710f..4de2000 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java
@@ -38,7 +38,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder(NAME)
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
-      .returnTypeNonNull(SqlTypeName.BIGINT)
+      .returnTypeCascadeNullable(SqlTypeName.BIGINT)
       .functionCategory(SqlFunctionCategory.STRING)
       .requiredOperands(1)
       .build();
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java
index 47c8ead..5ab8454 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java
@@ -37,7 +37,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder("RPAD")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .functionCategory(SqlFunctionCategory.STRING)
       .requiredOperands(2)
       .build();
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java
index 6aa8f1b..bc96610 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java
@@ -37,7 +37,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder("RTRIM")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .functionCategory(SqlFunctionCategory.STRING)
       .requiredOperands(1)
       .build();
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java
index 9521a04..55b01be 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java
@@ -39,7 +39,7 @@
       .operatorBuilder("REPEAT")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java
index 70280ab..6014231 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java
@@ -37,7 +37,7 @@
       .operatorBuilder("REVERSE")
       .operandTypes(SqlTypeFamily.CHARACTER)
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java
index 863bbcc..5f454a5 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java
@@ -39,7 +39,7 @@
       .operatorBuilder("RIGHT")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java
index b2aabbb..133d622 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java
@@ -42,7 +42,7 @@
       .operatorBuilder("STRING_FORMAT")
       .operandTypeChecker(new StringFormatOperandTypeChecker())
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java
index e18c089..c36405f 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java
@@ -38,7 +38,7 @@
       .operatorBuilder("STRPOS")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
       .functionCategory(SqlFunctionCategory.STRING)
-      .returnTypeNonNull(SqlTypeName.INTEGER)
+      .returnTypeCascadeNullable(SqlTypeName.INTEGER)
       .build();
 
   @Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java
index ee160d6..c44375c 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java
@@ -36,7 +36,7 @@
       .operatorBuilder("textcat")
       .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
       .requiredOperands(2)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .functionCategory(SqlFunctionCategory.STRING)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java
index 81b2dfa..359612c 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java
@@ -41,7 +41,7 @@
       .operatorBuilder("TIME_CEIL")
       .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
       .requiredOperands(2)
-      .returnTypeNonNull(SqlTypeName.TIMESTAMP)
+      .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java
index 35accd1..000923c 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java
@@ -44,7 +44,7 @@
       .operatorBuilder("TIME_EXTRACT")
       .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
       .requiredOperands(2)
-      .returnTypeNonNull(SqlTypeName.BIGINT)
+      .returnTypeCascadeNullable(SqlTypeName.BIGINT)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java
index 87c07f2..20377a0 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java
@@ -56,7 +56,7 @@
       .operatorBuilder("TIME_FLOOR")
       .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)
       .requiredOperands(2)
-      .returnTypeNonNull(SqlTypeName.TIMESTAMP)
+      .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java
index 1f7b6f9..e44734f 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java
@@ -47,7 +47,7 @@
       .operatorBuilder("TIME_FORMAT")
       .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)
       .requiredOperands(1)
-      .returnTypeNonNull(SqlTypeName.VARCHAR)
+      .returnTypeCascadeNullable(SqlTypeName.VARCHAR)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java
index 25b05c4..a4fd210 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java
@@ -45,7 +45,7 @@
       .operatorBuilder("TIME_SHIFT")
       .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER)
       .requiredOperands(3)
-      .returnTypeNonNull(SqlTypeName.TIMESTAMP)
+      .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java
index ae45655..ece14e2 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java
@@ -39,7 +39,7 @@
   private static final SqlFunction SQL_FUNCTION = OperatorConversions
       .operatorBuilder("TIMESTAMP_TO_MILLIS")
       .operandTypes(SqlTypeFamily.TIMESTAMP)
-      .returnTypeNonNull(SqlTypeName.BIGINT)
+      .returnTypeCascadeNullable(SqlTypeName.BIGINT)
       .functionCategory(SqlFunctionCategory.TIMEDATE)
       .build();
 
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index f724cd0..5014a35 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -17562,4 +17562,53 @@
                                .build()),
         ImmutableList.of(new Object[]{6L}));
   }
+
+  @Test
+  public void testExpressionCounts() throws Exception
+  {
+    cannotVectorize();
+    testQuery(
+        "SELECT\n"
+        + " COUNT(reverse(dim2)),\n"
+        + " COUNT(left(dim2, 5)),\n"
+        + " COUNT(strpos(dim2, 'a'))\n"
+        + "FROM druid.numfoo",
+        ImmutableList.of(
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(CalciteTests.DATASOURCE3)
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .granularity(Granularities.ALL)
+                  .virtualColumns(
+                      expressionVirtualColumn("v0", "reverse(\"dim2\")", ValueType.STRING),
+                      expressionVirtualColumn("v1", "left(\"dim2\",5)", ValueType.STRING),
+                      expressionVirtualColumn("v2", "(strpos(\"dim2\",'a') + 1)", ValueType.LONG)
+                  )
+                  .aggregators(
+                      aggregators(
+                          new FilteredAggregatorFactory(
+                              new CountAggregatorFactory("a0"),
+                              not(selector("v0", null, null))
+                          ),
+                          new FilteredAggregatorFactory(
+                              new CountAggregatorFactory("a1"),
+                              not(selector("v1", null, null))
+                          ),
+                          new FilteredAggregatorFactory(
+                              new CountAggregatorFactory("a2"),
+                              not(selector("v2", null, null))
+                          )
+                      )
+                  )
+                  .context(QUERY_CONTEXT_DEFAULT)
+                  .build()
+        ),
+        ImmutableList.of(
+            useDefault
+            // in default mode strpos is 6 because the '+ 1' of the expression (no null numbers in
+            // default mode so is 0 + 1 for null rows)
+            ? new Object[]{3L, 3L, 6L}
+            : new Object[]{4L, 4L, 4L}
+        )
+    );
+  }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java
index 0268bb6..5f70dc5 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java
@@ -31,12 +31,14 @@
 import org.apache.calcite.sql.SqlOperandCountRange;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.type.SqlOperandTypeChecker;
+import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
 import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.validate.SqlValidator;
 import org.apache.calcite.sql.validate.SqlValidatorScope;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker;
+import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
@@ -276,6 +278,69 @@
     }
 
     @Test
+    public void testNullForNullableOperandNonNullOutput()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNullForNullableNonnull")
+          .operandTypes(SqlTypeFamily.CHARACTER)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      SqlCallBinding binding = mockCallBinding(
+          function,
+          ImmutableList.of(
+              new OperandSpec(SqlTypeName.CHAR, false, true)
+          )
+      );
+      Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
+      RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
+      Assert.assertFalse(returnType.isNullable());
+    }
+
+    @Test
+    public void testNullForNullableOperandCascadeNullOutput()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNullForNullableCascade")
+          .operandTypes(SqlTypeFamily.CHARACTER)
+          .requiredOperands(1)
+          .returnTypeCascadeNullable(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      SqlCallBinding binding = mockCallBinding(
+          function,
+          ImmutableList.of(
+              new OperandSpec(SqlTypeName.CHAR, false, true)
+          )
+      );
+      Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
+      RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
+      Assert.assertTrue(returnType.isNullable());
+    }
+
+    @Test
+    public void testNullForNullableOperandAlwaysNullableOutput()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNullForNullableNonnull")
+          .operandTypes(SqlTypeFamily.CHARACTER)
+          .requiredOperands(1)
+          .returnTypeNullable(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      SqlCallBinding binding = mockCallBinding(
+          function,
+          ImmutableList.of(
+              new OperandSpec(SqlTypeName.CHAR, false, false)
+          )
+      );
+      Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
+      RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
+      Assert.assertTrue(returnType.isNullable());
+    }
+
+    @Test
     public void testNullForNonNullableOperand()
     {
       SqlFunction function = OperatorConversions
@@ -359,6 +424,7 @@
     )
     {
       SqlValidator validator = Mockito.mock(SqlValidator.class);
+      Mockito.when(validator.getTypeFactory()).thenReturn(new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE));
       List<SqlNode> operands = new ArrayList<>(actualOperands.size());
       for (OperandSpec operand : actualOperands) {
         final SqlNode node;
@@ -368,6 +434,12 @@
           node = Mockito.mock(SqlNode.class);
         }
         RelDataType relDataType = Mockito.mock(RelDataType.class);
+
+        if (operand.isNullable) {
+          Mockito.when(relDataType.isNullable()).thenReturn(true);
+        } else {
+          Mockito.when(relDataType.isNullable()).thenReturn(false);
+        }
         Mockito.when(validator.deriveType(ArgumentMatchers.any(), ArgumentMatchers.eq(node)))
                .thenReturn(relDataType);
         Mockito.when(relDataType.getSqlTypeName()).thenReturn(operand.type);
@@ -394,11 +466,18 @@
     {
       private final SqlTypeName type;
       private final boolean isLiteral;
+      private final boolean isNullable;
 
       private OperandSpec(SqlTypeName type, boolean isLiteral)
       {
+        this(type, isLiteral, type == SqlTypeName.NULL);
+      }
+
+      private OperandSpec(SqlTypeName type, boolean isLiteral, boolean isNullable)
+      {
         this.type = type;
         this.isLiteral = isLiteral;
+        this.isNullable = isNullable;
       }
     }
   }