[HUDI-7569] [RLI] Fix wrong results for IN queries on non-record-key columns (#10955)
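
Before this change, the IN-filter handling in RecordLevelIndexSupport accepted
any IN expression whose value is an AttributeReference, without checking that
the referenced column is actually the record key. With data skipping enabled,
the record-level index then used the IN literals as record keys for file
pruning, so a query such as

    select * from dummy_table where not_record_key_col in ('row1', 'abc')

could prune files that actually contain matching rows and return wrong
results. The fix marks such an IN query as invalid for index-based pruning
unless attributeMatchesRecordKey confirms the column is the record key.
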
Co-authored-by: Vinaykumar Bhat <vinay@onehouse.ai>
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/RecordLevelIndexSupport.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/RecordLevelIndexSupport.scala
index 764ce69..894405a 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/RecordLevelIndexSupport.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/RecordLevelIndexSupport.scala
@@ -162,7 +162,11 @@
       case inQuery: In =>
        var validINQuery = true
        inQuery.value match {
-          case _: AttributeReference =>
+          case attribute: AttributeReference =>
+            // An IN filter can be served by the record-level index only when it targets the record key column.
+            if (!attributeMatchesRecordKey(attribute.name)) {
+              validINQuery = false
+            }
          case _ => validINQuery = false
        }
        var literals: List[String] = List.empty
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestRecordLevelIndexWithSQL.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestRecordLevelIndexWithSQL.scala
index 8e23596..97fdc1e 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestRecordLevelIndexWithSQL.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestRecordLevelIndexWithSQL.scala
@@ -26,7 +26,8 @@
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, Literal, Or}
 import org.apache.spark.sql.types.StringType
 import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
-import org.junit.jupiter.api.Tag
+import org.junit.jupiter.api.io.TempDir
+import org.junit.jupiter.api.{Tag, Test}
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
@@ -155,4 +156,38 @@
     val readDf = spark.read.format("hudi").options(hudiOpts).load(basePath)
     readDf.registerTempTable(sqlTempTable)
   }
+
+  @Test
+  def testInFilterOnNonRecordKey(): Unit = {
+    val hudiOpts = commonOpts + (
+      DataSourceWriteOptions.TABLE_TYPE.key -> HoodieTableType.COPY_ON_WRITE.name(),
+      DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "true")
+
+    val dummyTablePath = tempDir.resolve("dummy_table").toAbsolutePath.toString
+    spark.sql(
+      s"""
+         |create table dummy_table (
+         |  record_key_col string,
+         |  not_record_key_col string,
+         |  partition_key_col string
+         |) using hudi
+         | options (
+         |  primaryKey = 'record_key_col',
+         |  hoodie.metadata.enable = 'true',
+         |  hoodie.metadata.record.index.enable = 'true',
+         |  hoodie.datasource.write.recordkey.field = 'record_key_col',
+         |  hoodie.enable.data.skipping = 'true'
+         | )
+         | partitioned by (partition_key_col)
+         | location '$dummyTablePath'
+       """.stripMargin)
+    spark.sql("insert into dummy_table values('row1', 'row2', 'p1')")
+    spark.sql("insert into dummy_table values('row2', 'row1', 'p2')")
+    spark.sql("insert into dummy_table values('row3', 'row1', 'p2')")
+
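+    // 'row1' appears in not_record_key_col of two rows ('row2' and 'row3'). The IN filter must be
+    // evaluated as a plain predicate; treating its literals as record keys for index lookup would
+    // prune away the files that actually contain the matching rows.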
+    assertEquals(2, spark.read.format("hudi").options(hudiOpts).load(dummyTablePath)
+      .filter("not_record_key_col in ('row1', 'abc')").count())
+  }
 }