[Improve](connector) fix V2Expression build error (#339)
Fix the V2Expression build error that surfaced as detailMessage = Unknown column 'null' in 'table list'.

Root cause: in the Spark 3.3/3.4/3.5 connectors, V2ExpressionBuilder#build returned null for expressions it cannot translate (for example COALESCE). When such an expression was nested inside an otherwise supported predicate, the null result was interpolated into the compiled filter string as the literal column name "null", and the Doris FE rejected the query. build now throws IllegalArgumentException for unsupported expressions, and the new buildOpt catches the failure, logs a warning, and returns None, so the whole predicate is reported back to Spark as not pushed down. Backtick quoting of column names also moves into the NamedReference case so that nested references are quoted correctly.
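For context, a minimal sketch of the old failure mode (illustrative Scala only, not connector code; the variable names are made up for the example):

    // Before this patch, build() returned null for an unsupported child such
    // as COALESCE(CAST(A4 AS STRING), 'null'). String interpolation then
    // turned that null into the literal column name "null":
    val unsupportedChild: String = null            // old build() result
    val filter = s"`$unsupportedChild` IN ('a4')"  // yields: `null` IN ('a4')
    // Doris FE rejects this filter with: Unknown column 'null' in 'table list'

With this patch, build() throws instead of returning null, buildOpt maps the failure to None, and pushPredicates hands the predicate back to Spark, which evaluates it after the scan.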
---------
Co-authored-by: wangguoxing <wangguoxing@kingsoft.com>
diff --git a/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala b/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala
index 5fa9112..2ed6188 100644
--- a/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala
+++ b/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala
@@ -56,6 +56,7 @@
val TABLE_READ_UTF8_TBL = "tbl_read_utf8_tbl"
val TABLE_READ_TBL_ALL_TYPES = "tbl_read_tbl_all_types"
val TABLE_READ_TBL_BIT_MAP = "tbl_read_tbl_bitmap"
+ val TABLE_READ_EXPRESSION_NOTPUSHDOWN = "tbl_expression_notpushdown"
@Before
def setUp(): Unit = {
@@ -560,4 +561,76 @@
assert("List([1], [2])".equals(prefixTest.toList.toString()))
}
}
+
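+  // Regression tests for #339: predicates containing COALESCE cannot be
+  // pushed down; previously the connector compiled them into a filter with
+  // the literal column name "null" and Doris failed with
+  // "Unknown column 'null' in 'table list'".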
+ @Test
+ def testExpressionNotPushDown(): Unit = {
+ val sourceInitSql: Array[String] = ContainerUtils.parseFileContentSQL("container/ddl/read_filter_pushdown.sql")
+ ContainerUtils.executeSQLStatement(getDorisQueryConnection(DATABASE), LOG, sourceInitSql: _*)
+
+ val session = SparkSession.builder().master("local[*]").getOrCreate()
+ try {
+ session.sql(
+ s"""
+ |CREATE TEMPORARY VIEW test_source
+ |USING doris
+ |OPTIONS(
+ | "table.identifier"="${DATABASE + "." + TABLE_READ_EXPRESSION_NOTPUSHDOWN}",
+ | "fenodes"="${getFenodes}",
+ | "user"="${getDorisUsername}",
+ | "password"="${getDorisPassword}"
+ |)
+ |""".stripMargin)
+
+ val resultData = session.sql(
+ """
+ |select COALESCE(CAST(A4 AS STRING),'null')
+ |from test_source where COALESCE(CAST(A4 AS STRING),'null') in ('a4')
+ |""".stripMargin)
+
+ val rows = resultData.collect().toList.toString()
+ println(rows)
+ assert("List([a4], [a4], [a4], [a4], [a4])".equals(rows))
+
+ val resultData1 = session.sql(
+ """
+ |select COALESCE(CAST(NAME AS STRING),'null')
+ |from test_source where COALESCE(CAST(NAME AS STRING),'null') in ('name2')
+ |""".stripMargin)
+
+ assert("List([name2])".equals(resultData1.collect().toList.toString()))
+ } finally {
+ session.stop()
+ }
+ }
+
+ @Test
+ def testDataFrameExpressionNotPushDown(): Unit = {
+ val sourceInitSql: Array[String] = ContainerUtils.parseFileContentSQL("container/ddl/read_filter_pushdown.sql")
+ ContainerUtils.executeSQLStatement(getDorisQueryConnection(DATABASE), LOG, sourceInitSql: _*)
+
+ val session = SparkSession.builder().master("local[*]").getOrCreate()
+ try {
+ val df = session.read
+ .format("doris")
+ .option("doris.fenodes", getFenodes)
+ .option("doris.table.identifier", DATABASE + "." + TABLE_READ_EXPRESSION_NOTPUSHDOWN)
+ .option("user", getDorisUsername)
+ .option("password", getDorisPassword)
+ .load()
+
+ import org.apache.spark.sql.functions._
+
+ val resultData = df.select(coalesce(col("A4").cast("string"), lit("null")).as("coalesced_A4"))
+ .filter(col("coalesced_A4").isin("a4"))
+
+ val rows = resultData.collect().toList.toString()
+ println(rows)
+ assert("List([a4], [a4], [a4], [a4], [a4])".equals(rows))
+
+ val resultData1 = df.select(coalesce(col("NAME").cast("string"), lit("null")).as("coalesced_NAME"))
+ .filter(col("coalesced_NAME").isin("name2"))
+
+ assert("List([name2])".equals(resultData1.collect().toList.toString()))
+ } finally {
+ session.stop()
+ }
+ }
}
diff --git a/spark-doris-connector/spark-doris-connector-it/src/test/resources/container/ddl/read_filter_pushdown.sql b/spark-doris-connector/spark-doris-connector-it/src/test/resources/container/ddl/read_filter_pushdown.sql
new file mode 100644
index 0000000..a2176de
--- /dev/null
+++ b/spark-doris-connector/spark-doris-connector-it/src/test/resources/container/ddl/read_filter_pushdown.sql
@@ -0,0 +1,29 @@
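+-- Fixture for the DorisReaderITCase expression-pushdown tests: a table that
+-- is queried with COALESCE(...) predicates, which must be evaluated by Spark
+-- rather than pushed down to Doris.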
+DROP TABLE IF EXISTS tbl_expression_notpushdown;
+CREATE TABLE tbl_expression_notpushdown (
+ ID decimal(38,10) NULL,
+ NAME varchar(300) NULL,
+ AGE decimal(38,10) NULL,
+ CREATE_TIME datetime(3) NULL,
+ A1 varchar(300) NULL,
+ A2 varchar(300) NULL,
+ A3 varchar(300) NULL,
+ A4 varchar(300) NULL,
+ __source_ts_ms bigint NULL,
+ __op varchar(10) NULL,
+ __table varchar(100) NULL,
+ __db varchar(50) NULL,
+ __deleted varchar(10) NULL,
+ __dt datetime NULL
+) ENGINE=OLAP
+DUPLICATE KEY(ID)
+DISTRIBUTED BY HASH(`ID`) BUCKETS 2
+PROPERTIES (
+"replication_num" = "1",
+"light_schema_change" = "true"
+);
+
+insert into tbl_expression_notpushdown values(1, 'name1', 18, '2021-01-01 00:00:00.000', 'a1', 'a2', 'a3', 'a4', 1609459200000, 'c', 'tbl_read_tbl_all_types', 'db_read_tbl_all_types', false, '2021-01-01 00:00:00');
+insert into tbl_expression_notpushdown values(2, 'name2', 19, '2021-01-01 00:00:00.000', 'a1', 'a2', 'a3', 'a4', 1609459200000, 'c', 'tbl_read_tbl_all_types', 'db_read_tbl_all_types', false, '2021-01-01 00:00:00');
+insert into tbl_expression_notpushdown values(3, 'name3', 20, '2021-01-01 00:00:00.000', 'a1', 'a2', 'a3', 'a4', 1609459200000, 'c', 'tbl_read_tbl_all_types', 'db_read_tbl_all_types', false, '2021-01-01 00:00:00');
+insert into tbl_expression_notpushdown values(4, 'name4', 21, '2021-01-01 00:00:00.000', 'a1', 'a2', 'a3', 'a4', 1609459200000, 'c', 'tbl_read_tbl_all_types', 'db_read_tbl_all_types', false, '2021-01-01 00:00:00');
+insert into tbl_expression_notpushdown values(5, 'name5', 22, '2021-01-01 00:00:00.000', 'a1', 'a2', 'a3', 'a4', 1609459200000, 'c', 'tbl_read_tbl_all_types', 'db_read_tbl_all_types', false, '2021-01-01 00:00:00');
\ No newline at end of file
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
index 61a9e20..26b45f4 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
@@ -37,7 +37,7 @@
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val (pushed, unsupported) = predicates.partition(predicate => {
- Option(expressionBuilder.build(predicate)).isDefined
+ expressionBuilder.buildOpt(predicate).isDefined
})
this.pushDownPredicates = pushed
unsupported
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
index 503ad04..5e0104b 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
@@ -27,7 +27,7 @@
override protected def compiledFilters(): Array[String] = {
val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
- filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
+ filters.map(e => v2ExpressionBuilder.buildOpt(e)).filter(_.isDefined).map(_.get)
}
override protected def getLimit: Int = limit
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
index cba20f1..880784c 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
@@ -17,12 +17,26 @@
package org.apache.doris.spark.read.expression
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And, Not, Or}
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference}
import org.apache.spark.sql.types.{DateType, TimestampType}
-class V2ExpressionBuilder(inValueLengthLimit: Int) {
+import scala.util.{Failure, Success, Try}
+
+class V2ExpressionBuilder(inValueLengthLimit: Int) extends Logging {
+
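+ // Tries to compile the predicate into a Doris filter string. Returns None
+ // (after logging a warning) when any sub-expression is unsupported, so the
+ // predicate is reported back to Spark as not pushed down instead of being
+ // compiled into a broken filter such as `null` IN ('a4').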
+ def buildOpt(predicate: Expression): Option[String] = {
+ Try {
+ Some(build(predicate))
+ } match {
+ case Success(value) => value
+ case Failure(exception) =>
+ logWarning(s"Failed to build expression: ${predicate.toString}, and not support predicate push down, errMsg is ${exception.getMessage}")
+ None
+ }
+ }
def build(predicate: Expression): String = {
predicate match {
@@ -39,24 +53,24 @@
case expr: Expression =>
expr match {
case literal: Literal[_] => visitLiteral(literal)
- case namedRef: NamedReference => namedRef.toString
+ case namedRef: NamedReference => s"`${namedRef.toString}`"
case e: GeneralScalarExpression => e.name() match {
case "IN" =>
val expressions = e.children()
if (expressions.nonEmpty && expressions.length <= inValueLengthLimit) {
- s"""`${build(expressions(0))}` IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
- } else null
- case "IS_NULL" => s"`${build(e.children()(0))}` IS NULL"
- case "IS_NOT_NULL" => s"`${build(e.children()(0))}` IS NOT NULL"
+ s"""${build(expressions(0))} IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
+ } else throw new IllegalArgumentException(s"number of IN values exceeds limit: actual size ${expressions.length}, limit size $inValueLengthLimit")
+ case "IS_NULL" => s"${build(e.children()(0))} IS NULL"
+ case "IS_NOT_NULL" => s"${build(e.children()(0))} IS NOT NULL"
case "STARTS_WITH" => visitStartWith(build(e.children()(0)), build(e.children()(1)));
case "ENDS_WITH" => visitEndWith(build(e.children()(0)), build(e.children()(1)));
case "CONTAINS" => visitContains(build(e.children()(0)), build(e.children()(1)));
- case "=" => s"`${build(e.children()(0))}` = ${build(e.children()(1))}"
- case "!=" | "<>" => s"`${build(e.children()(0))}` != ${build(e.children()(1))}"
- case "<" => s"`${build(e.children()(0))}` < ${build(e.children()(1))}"
- case "<=" => s"`${build(e.children()(0))}` <= ${build(e.children()(1))}"
- case ">" => s"`${build(e.children()(0))}` > ${build(e.children()(1))}"
- case ">=" => s"`${build(e.children()(0))}` >= ${build(e.children()(1))}"
+ case "=" => s"${build(e.children()(0))} = ${build(e.children()(1))}"
+ case "!=" | "<>" => s"${build(e.children()(0))} != ${build(e.children()(1))}"
+ case "<" => s"${build(e.children()(0))} < ${build(e.children()(1))}"
+ case "<=" => s"${build(e.children()(0))} <= ${build(e.children()(1))}"
+ case ">" => s"${build(e.children()(0))} > ${build(e.children()(1))}"
+ case ">=" => s"${build(e.children()(0))} >= ${build(e.children()(1))}"
case "CASE_WHEN" =>
val fragment = new StringBuilder("CASE ")
val expressions = e.children()
@@ -72,7 +86,7 @@
fragment.append(" END")
fragment.mkString
- case _ => null
+ case _ => throw new IllegalArgumentException(s"Unsupported expression: ${e.name()}")
}
}
}
@@ -91,17 +105,17 @@
def visitStartWith(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '$value%'"
+ s"$l LIKE '$value%'"
}
def visitEndWith(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '%$value'"
+ s"$l LIKE '%$value'"
}
def visitContains(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '%$value%'"
+ s"$l LIKE '%$value%'"
}
}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
index fc29495..6e72689 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
@@ -17,6 +17,8 @@
// specific language governing permissions and limitations
// under the License.
+
+import org.apache.spark.sql.ExpressionUtil
import org.apache.spark.sql.sources._
import org.junit.jupiter.api.{Assertions, Test}
@@ -42,8 +44,23 @@
Assertions.assertEquals(builder.build(Or(EqualTo("c14", 17), EqualTo("c15", 18)).toV2), "(`c14` = 17 OR `c15` = 18)")
Assertions.assertEquals(builder.build(AlwaysTrue.toV2), "1=1")
Assertions.assertEquals(builder.build(AlwaysFalse.toV2), "1=0")
- Assertions.assertNull(builder.build(In("c19", Array(19,20,21,22,23,24,25,26,27,28,29)).toV2))
+ Assertions.assertEquals(builder.build(In("c19", Array(19,20,21,22,23,24,25,26)).toV2), "`c19` IN (19,20,21,22,23,24,25,26)")
+ Assertions.assertEquals(builder.build(In("c19", Array("19","20")).toV2), "`c19` IN ('19','20')")
+ val inException = Assertions.assertThrows(classOf[IllegalArgumentException], () => builder.build(In("c19", Array(19,20,21,22,23,24,25,26,27,28,29)).toV2))
+ Assertions.assertEquals(inException.getMessage, "number of IN values exceeds limit: actual size 12, limit size 10")
+ val exception = Assertions.assertThrows(classOf[IllegalArgumentException], () => builder.build(ExpressionUtil.buildCoalesceFilter()))
+ Assertions.assertEquals(exception.getMessage, "Unsupported expression: COALESCE")
}
+ @Test
+ def buildOptTest(): Unit = {
+
+ val builder = new V2ExpressionBuilder(10)
+ Assertions.assertEquals(builder.buildOpt(EqualTo("c0", 1).toV2), Some("`c0` = 1"))
+ Assertions.assertEquals(builder.buildOpt(Not(EqualTo("c1", 2)).toV2), Some("`c1` != 2"))
+ Assertions.assertEquals(builder.buildOpt(GreaterThan("c2", 3.4).toV2), Some("`c2` > 3.4"))
+ Assertions.assertEquals(builder.buildOpt(ExpressionUtil.buildCoalesceFilter()), None)
+ }
+
}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala
new file mode 100644
index 0000000..84b8beb
--- /dev/null
+++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala
@@ -0,0 +1,31 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.types.StringType
+
+object ExpressionUtil {
+
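+ // Builds a predicate of the form COALESCE(A4, 'null') = '1'. COALESCE is
+ // not translatable by V2ExpressionBuilder, so this exercises the
+ // unsupported-expression path in build()/buildOpt().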
+ def buildCoalesceFilter(): Expression = {
+ val gse = new GeneralScalarExpression("COALESCE", Array(FieldReference(Seq("A4")), LiteralValue("null", StringType)))
+ new Predicate("=", Array(gse, LiteralValue("1", StringType)))
+ }
+
+}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
index 61a9e20..26b45f4 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
@@ -37,7 +37,7 @@
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val (pushed, unsupported) = predicates.partition(predicate => {
- Option(expressionBuilder.build(predicate)).isDefined
+ expressionBuilder.buildOpt(predicate).isDefined
})
this.pushDownPredicates = pushed
unsupported
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
index 503ad04..5e0104b 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
@@ -27,7 +27,7 @@
override protected def compiledFilters(): Array[String] = {
val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
- filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
+ filters.map(e => v2ExpressionBuilder.buildOpt(e)).filter(_.isDefined).map(_.get)
}
override protected def getLimit: Int = limit
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
index cba20f1..880784c 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
@@ -17,12 +17,26 @@
package org.apache.doris.spark.read.expression
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And, Not, Or}
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference}
import org.apache.spark.sql.types.{DateType, TimestampType}
-class V2ExpressionBuilder(inValueLengthLimit: Int) {
+import scala.util.{Failure, Success, Try}
+
+class V2ExpressionBuilder(inValueLengthLimit: Int) extends Logging {
+
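+ // Tries to compile the predicate into a Doris filter string. Returns None
+ // (after logging a warning) when any sub-expression is unsupported, so the
+ // predicate is reported back to Spark as not pushed down instead of being
+ // compiled into a broken filter such as `null` IN ('a4').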
+ def buildOpt(predicate: Expression): Option[String] = {
+ Try {
+ Some(build(predicate))
+ } match {
+ case Success(value) => value
+ case Failure(exception) =>
+ logWarning(s"Failed to build expression: ${predicate.toString}, and not support predicate push down, errMsg is ${exception.getMessage}")
+ None
+ }
+ }
def build(predicate: Expression): String = {
predicate match {
@@ -39,24 +53,24 @@
case expr: Expression =>
expr match {
case literal: Literal[_] => visitLiteral(literal)
- case namedRef: NamedReference => namedRef.toString
+ case namedRef: NamedReference => s"`${namedRef.toString}`"
case e: GeneralScalarExpression => e.name() match {
case "IN" =>
val expressions = e.children()
if (expressions.nonEmpty && expressions.length <= inValueLengthLimit) {
- s"""`${build(expressions(0))}` IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
- } else null
- case "IS_NULL" => s"`${build(e.children()(0))}` IS NULL"
- case "IS_NOT_NULL" => s"`${build(e.children()(0))}` IS NOT NULL"
+ s"""${build(expressions(0))} IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
+ } else throw new IllegalArgumentException(s"number of IN values exceeds limit: actual size ${expressions.length}, limit size $inValueLengthLimit")
+ case "IS_NULL" => s"${build(e.children()(0))} IS NULL"
+ case "IS_NOT_NULL" => s"${build(e.children()(0))} IS NOT NULL"
case "STARTS_WITH" => visitStartWith(build(e.children()(0)), build(e.children()(1)));
case "ENDS_WITH" => visitEndWith(build(e.children()(0)), build(e.children()(1)));
case "CONTAINS" => visitContains(build(e.children()(0)), build(e.children()(1)));
- case "=" => s"`${build(e.children()(0))}` = ${build(e.children()(1))}"
- case "!=" | "<>" => s"`${build(e.children()(0))}` != ${build(e.children()(1))}"
- case "<" => s"`${build(e.children()(0))}` < ${build(e.children()(1))}"
- case "<=" => s"`${build(e.children()(0))}` <= ${build(e.children()(1))}"
- case ">" => s"`${build(e.children()(0))}` > ${build(e.children()(1))}"
- case ">=" => s"`${build(e.children()(0))}` >= ${build(e.children()(1))}"
+ case "=" => s"${build(e.children()(0))} = ${build(e.children()(1))}"
+ case "!=" | "<>" => s"${build(e.children()(0))} != ${build(e.children()(1))}"
+ case "<" => s"${build(e.children()(0))} < ${build(e.children()(1))}"
+ case "<=" => s"${build(e.children()(0))} <= ${build(e.children()(1))}"
+ case ">" => s"${build(e.children()(0))} > ${build(e.children()(1))}"
+ case ">=" => s"${build(e.children()(0))} >= ${build(e.children()(1))}"
case "CASE_WHEN" =>
val fragment = new StringBuilder("CASE ")
val expressions = e.children()
@@ -72,7 +86,7 @@
fragment.append(" END")
fragment.mkString
- case _ => null
+ case _ => throw new IllegalArgumentException(s"Unsupported expression: ${e.name()}")
}
}
}
@@ -91,17 +105,17 @@
def visitStartWith(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '$value%'"
+ s"$l LIKE '$value%'"
}
def visitEndWith(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '%$value'"
+ s"$l LIKE '%$value'"
}
def visitContains(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '%$value%'"
+ s"$l LIKE '%$value%'"
}
}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
index fc29495..7691b29 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.4/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
@@ -17,6 +17,7 @@
// specific language governing permissions and limitations
// under the License.
+import org.apache.spark.sql.ExpressionUtil
import org.apache.spark.sql.sources._
import org.junit.jupiter.api.{Assertions, Test}
@@ -42,8 +43,23 @@
Assertions.assertEquals(builder.build(Or(EqualTo("c14", 17), EqualTo("c15", 18)).toV2), "(`c14` = 17 OR `c15` = 18)")
Assertions.assertEquals(builder.build(AlwaysTrue.toV2), "1=1")
Assertions.assertEquals(builder.build(AlwaysFalse.toV2), "1=0")
- Assertions.assertNull(builder.build(In("c19", Array(19,20,21,22,23,24,25,26,27,28,29)).toV2))
+ Assertions.assertEquals(builder.build(In("c19", Array(19,20,21,22,23,24,25,26)).toV2), "`c19` IN (19,20,21,22,23,24,25,26)")
+ Assertions.assertEquals(builder.build(In("c19", Array("19","20")).toV2), "`c19` IN ('19','20')")
+ val inException = Assertions.assertThrows(classOf[IllegalArgumentException], () => builder.build(In("c19", Array(19,20,21,22,23,24,25,26,27,28,29)).toV2))
+ Assertions.assertEquals(inException.getMessage, "number of IN values exceeds limit: actual size 12, limit size 10")
+ val exception = Assertions.assertThrows(classOf[IllegalArgumentException], () => builder.build(ExpressionUtil.buildCoalesceFilter()))
+ Assertions.assertEquals(exception.getMessage, "Unsupported expression: COALESCE")
}
+ @Test
+ def buildOptTest(): Unit = {
+
+ val builder = new V2ExpressionBuilder(10)
+ Assertions.assertEquals(builder.buildOpt(EqualTo("c0", 1).toV2), Some("`c0` = 1"))
+ Assertions.assertEquals(builder.buildOpt(Not(EqualTo("c1", 2)).toV2), Some("`c1` != 2"))
+ Assertions.assertEquals(builder.buildOpt(GreaterThan("c2", 3.4).toV2), Some("`c2` > 3.4"))
+ Assertions.assertEquals(builder.buildOpt(ExpressionUtil.buildCoalesceFilter()), None)
+ }
+
}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala
new file mode 100644
index 0000000..2da931a
--- /dev/null
+++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.types.StringType
+
+object ExpressionUtil {
+
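+ // Builds a predicate of the form COALESCE(A4, 'null') = '1'. COALESCE is
+ // not translatable by V2ExpressionBuilder, so this exercises the
+ // unsupported-expression path in build()/buildOpt().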
+ def buildCoalesceFilter(): Expression = {
+ val gse = new GeneralScalarExpression("COALESCE", Array(FieldReference(Seq("A4")), LiteralValue("null", StringType)))
+ new Predicate("=", Array(gse, LiteralValue("1", StringType)))
+ }
+}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
index 61a9e20..26b45f4 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
@@ -37,7 +37,7 @@
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val (pushed, unsupported) = predicates.partition(predicate => {
- Option(expressionBuilder.build(predicate)).isDefined
+ expressionBuilder.buildOpt(predicate).isDefined
})
this.pushDownPredicates = pushed
unsupported
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
index 503ad04..5e0104b 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
@@ -27,7 +27,7 @@
override protected def compiledFilters(): Array[String] = {
val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
- filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
+ filters.map(e => v2ExpressionBuilder.buildOpt(e)).filter(_.isDefined).map(_.get)
}
override protected def getLimit: Int = limit
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
index 6e1c010..9a86a69 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala
@@ -17,12 +17,26 @@
package org.apache.doris.spark.read.expression
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And, Not, Or}
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference}
import org.apache.spark.sql.types.{DateType, TimestampType}
-class V2ExpressionBuilder(inValueLengthLimit: Int) {
+import scala.util.{Failure, Success, Try}
+
+class V2ExpressionBuilder(inValueLengthLimit: Int) extends Logging {
+
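+ // Tries to compile the predicate into a Doris filter string. Returns None
+ // (after logging a warning) when any sub-expression is unsupported, so the
+ // predicate is reported back to Spark as not pushed down instead of being
+ // compiled into a broken filter such as `null` IN ('a4').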
+ def buildOpt(predicate: Expression): Option[String] = {
+ Try {
+ Some(build(predicate))
+ } match {
+ case Success(value) => value
+ case Failure(exception) =>
+ logWarning(s"Failed to build expression: ${predicate.toString}, and not support predicate push down, errMsg is ${exception.getMessage}")
+ None
+ }
+ }
def build(predicate: Expression): String = {
predicate match {
@@ -39,24 +53,24 @@
case expr: Expression =>
expr match {
case literal: Literal[_] => visitLiteral(literal)
- case namedRef: NamedReference => namedRef.toString
+ case namedRef: NamedReference => s"`${namedRef.toString}`"
case e: GeneralScalarExpression => e.name() match {
case "IN" =>
val expressions = e.children()
if (expressions.nonEmpty && expressions.length <= inValueLengthLimit) {
- s"""`${build(expressions(0))}` IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
- } else null
- case "IS_NULL" => s"`${build(e.children()(0))}` IS NULL"
- case "IS_NOT_NULL" => s"`${build(e.children()(0))}` IS NOT NULL"
+ s"""${build(expressions(0))} IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
+ } else throw new IllegalArgumentException(s"number of IN values exceeds limit: actual size ${expressions.length}, limit size $inValueLengthLimit")
+ case "IS_NULL" => s"${build(e.children()(0))} IS NULL"
+ case "IS_NOT_NULL" => s"${build(e.children()(0))} IS NOT NULL"
case "STARTS_WITH" => visitStartWith(build(e.children()(0)), build(e.children()(1)));
case "ENDS_WITH" => visitEndWith(build(e.children()(0)), build(e.children()(1)));
case "CONTAINS" => visitContains(build(e.children()(0)), build(e.children()(1)));
- case "=" => s"`${build(e.children()(0))}` = ${build(e.children()(1))}"
- case "!=" | "<>" => s"`${build(e.children()(0))}` != ${build(e.children()(1))}"
- case "<" => s"`${build(e.children()(0))}` < ${build(e.children()(1))}"
- case "<=" => s"`${build(e.children()(0))}` <= ${build(e.children()(1))}"
- case ">" => s"`${build(e.children()(0))}` > ${build(e.children()(1))}"
- case ">=" => s"`${build(e.children()(0))}` >= ${build(e.children()(1))}"
+ case "=" => s"${build(e.children()(0))} = ${build(e.children()(1))}"
+ case "!=" | "<>" => s"${build(e.children()(0))} != ${build(e.children()(1))}"
+ case "<" => s"${build(e.children()(0))} < ${build(e.children()(1))}"
+ case "<=" => s"${build(e.children()(0))} <= ${build(e.children()(1))}"
+ case ">" => s"${build(e.children()(0))} > ${build(e.children()(1))}"
+ case ">=" => s"${build(e.children()(0))} >= ${build(e.children()(1))}"
case "CASE_WHEN" =>
val fragment = new StringBuilder("CASE ")
val expressions = e.children()
@@ -72,7 +86,7 @@
fragment.append(" END")
fragment.mkString
- case _ => null
+ case _ => throw new IllegalArgumentException(s"Unsupported expression: ${e.name()}")
}
}
}
@@ -90,17 +104,17 @@
}
def visitStartWith(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '$value%'"
+ s"$l LIKE '$value%'"
}
def visitEndWith(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '%$value'"
+ s"$l LIKE '%$value'"
}
def visitContains(l: String, r: String): String = {
val value = r.substring(1, r.length - 1)
- s"`$l` LIKE '%$value%'"
+ s"$l LIKE '%$value%'"
}
}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.5/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala b/spark-doris-connector/spark-doris-connector-spark-3.5/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
index fc29495..7691b29 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.5/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.5/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala
@@ -17,6 +17,7 @@
// specific language governing permissions and limitations
// under the License.
+import org.apache.spark.sql.ExpressionUtil
import org.apache.spark.sql.sources._
import org.junit.jupiter.api.{Assertions, Test}
@@ -42,8 +43,23 @@
Assertions.assertEquals(builder.build(Or(EqualTo("c14", 17), EqualTo("c15", 18)).toV2), "(`c14` = 17 OR `c15` = 18)")
Assertions.assertEquals(builder.build(AlwaysTrue.toV2), "1=1")
Assertions.assertEquals(builder.build(AlwaysFalse.toV2), "1=0")
- Assertions.assertNull(builder.build(In("c19", Array(19,20,21,22,23,24,25,26,27,28,29)).toV2))
+ Assertions.assertEquals(builder.build(In("c19", Array(19,20,21,22,23,24,25,26)).toV2), "`c19` IN (19,20,21,22,23,24,25,26)")
+ Assertions.assertEquals(builder.build(In("c19", Array("19","20")).toV2), "`c19` IN ('19','20')")
+ val inException = Assertions.assertThrows(classOf[IllegalArgumentException], () => builder.build(In("c19", Array(19,20,21,22,23,24,25,26,27,28,29)).toV2))
+ Assertions.assertEquals(inException.getMessage, "number of IN values exceeds limit: actual size 12, limit size 10")
+ val exception = Assertions.assertThrows(classOf[IllegalArgumentException], () => builder.build(ExpressionUtil.buildCoalesceFilter()))
+ Assertions.assertEquals(exception.getMessage, "Unsupported expression: COALESCE")
}
+ @Test
+ def buildOptTest(): Unit = {
+
+ val builder = new V2ExpressionBuilder(10)
+ Assertions.assertEquals(builder.buildOpt(EqualTo("c0", 1).toV2), Some("`c0` = 1"))
+ Assertions.assertEquals(builder.buildOpt(Not(EqualTo("c1", 2)).toV2), Some("`c1` != 2"))
+ Assertions.assertEquals(builder.buildOpt(GreaterThan("c2", 3.4).toV2), Some("`c2` > 3.4"))
+ Assertions.assertEquals(builder.buildOpt(ExpressionUtil.buildCoalesceFilter()), None)
+ }
+
}
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.5/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala b/spark-doris-connector/spark-doris-connector-spark-3.5/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala
new file mode 100644
index 0000000..2da931a
--- /dev/null
+++ b/spark-doris-connector/spark-doris-connector-spark-3.5/src/test/scala/org/apache/spark/sql/ExpressionUtil.scala
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.types.StringType
+
+object ExpressionUtil {
+
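+ // Builds a predicate of the form COALESCE(A4, 'null') = '1'. COALESCE is
+ // not translatable by V2ExpressionBuilder, so this exercises the
+ // unsupported-expression path in build()/buildOpt().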
+ def buildCoalesceFilter(): Expression = {
+ val gse = new GeneralScalarExpression("COALESCE", Array(FieldReference(Seq("A4")), LiteralValue("null", StringType)))
+ new Predicate("=", Array(gse, LiteralValue("1", StringType)))
+ }
+}