[HUDI-6320] Fix partition parsing in Spark file index for custom keygen (#9273)

diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
index 3767b65..a7e90b2 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
@@ -79,7 +79,7 @@
     spark = spark,
     metaClient = metaClient,
     schemaSpec = schemaSpec,
-    configProperties = getConfigProperties(spark, options),
+    configProperties = getConfigProperties(spark, options, metaClient),
     queryPaths = HoodieFileIndex.getQueryPaths(options),
     specifiedQueryInstant = options.get(DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key).map(HoodieSqlCommonUtils.formatQueryInstant),
     fileStatusCache = fileStatusCache
@@ -324,7 +324,7 @@
     schema.fieldNames.filter { colName => refs.exists(r => resolver.apply(colName, r.name)) }
   }
 
-  def getConfigProperties(spark: SparkSession, options: Map[String, String]) = {
+  def getConfigProperties(spark: SparkSession, options: Map[String, String], metaClient: HoodieTableMetaClient) = {
     val sqlConf: SQLConf = spark.sessionState.conf
     val properties = TypedProperties.fromMap(options.filter(p => p._2 != null).asJava)
 
@@ -342,6 +342,16 @@
     if (listingModeOverride != null) {
       properties.setProperty(DataSourceReadOptions.FILE_INDEX_LISTING_MODE_OVERRIDE.key, listingModeOverride)
     }
+    val partitionColumns = metaClient.getTableConfig.getPartitionFields
+    if (partitionColumns.isPresent) {
+      // NOTE: Multiple partition fields could have non-encoded slashes in the partition value.
+      //       We might not be able to properly parse partition-values from the listed partition-paths.
+      //       Fallback to eager listing in this case.
+      if (partitionColumns.get().length > 1
+        && (listingModeOverride == null || DataSourceReadOptions.FILE_INDEX_LISTING_MODE_LAZY.equals(listingModeOverride))) {
+        properties.setProperty(DataSourceReadOptions.FILE_INDEX_LISTING_MODE_OVERRIDE.key, DataSourceReadOptions.FILE_INDEX_LISTING_MODE_EAGER)
+      }
+    }
 
     properties
   }
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
index 35ef3e9..b3d9e56 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/SparkHoodieTableFileIndex.scala
@@ -29,11 +29,9 @@
 import org.apache.hudi.common.table.{HoodieTableMetaClient, TableSchemaResolver}
 import org.apache.hudi.common.util.ValidationUtils.checkState
 import org.apache.hudi.config.HoodieBootstrapConfig.DATA_QUERIES_ONLY
-import org.apache.hudi.hadoop.CachingPath
-import org.apache.hudi.hadoop.CachingPath.createRelativePathUnsafe
 import org.apache.hudi.internal.schema.Types.RecordType
 import org.apache.hudi.internal.schema.utils.Conversions
-import org.apache.hudi.keygen.{StringPartitionPathFormatter, TimestampBasedAvroKeyGenerator, TimestampBasedKeyGenerator}
+import org.apache.hudi.keygen.{CustomAvroKeyGenerator, CustomKeyGenerator, StringPartitionPathFormatter, TimestampBasedAvroKeyGenerator, TimestampBasedKeyGenerator}
 import org.apache.hudi.util.JFunction
 import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.internal.Logging
@@ -44,7 +42,6 @@
 import org.apache.spark.sql.execution.datasources.{FileStatusCache, NoopCache}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 import java.util.Collections
 import javax.annotation.concurrent.NotThreadSafe
@@ -99,7 +96,7 @@
       AvroConversionUtils.convertAvroSchemaToStructType(schemaUtil.getTableAvroSchema)
     })
 
-  protected lazy val shouldFastBootstrap = configProperties.getBoolean(DATA_QUERIES_ONLY.key, false)
+  protected lazy val shouldFastBootstrap: Boolean = configProperties.getBoolean(DATA_QUERIES_ONLY.key, false)
 
   private lazy val sparkParsePartitionUtil = sparkAdapter.getSparkParsePartitionUtil
 
@@ -115,14 +112,16 @@
       // Note that key generator class name could be null
       val keyGeneratorClassName = tableConfig.getKeyGeneratorClassName
       if (classOf[TimestampBasedKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)
-        || classOf[TimestampBasedAvroKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)) {
+        || classOf[TimestampBasedAvroKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)
+        || classOf[CustomKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)
+        || classOf[CustomAvroKeyGenerator].getName.equalsIgnoreCase(keyGeneratorClassName)) {
         val partitionFields = partitionColumns.get().map(column => StructField(column, StringType))
         StructType(partitionFields)
       } else {
         val partitionFields = partitionColumns.get().filter(column => nameFieldMap.contains(column))
           .map(column => nameFieldMap.apply(column))
 
-        if (partitionFields.size != partitionColumns.get().size) {
+        if (partitionFields.length != partitionColumns.get().length) {
           val isBootstrapTable = tableConfig.getBootstrapBasePath.isPresent
           if (isBootstrapTable) {
             // For bootstrapped tables its possible the schema does not contain partition field when source table
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/cdc/HoodieCDCRDD.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/cdc/HoodieCDCRDD.scala
index 839b028..521fb7f 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/cdc/HoodieCDCRDD.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/cdc/HoodieCDCRDD.scala
@@ -86,7 +86,7 @@
 
   private val cdcSupplementalLoggingMode = metaClient.getTableConfig.cdcSupplementalLoggingMode
 
-  private val props = HoodieFileIndex.getConfigProperties(spark, Map.empty)
+  private val props = HoodieFileIndex.getConfigProperties(spark, Map.empty, metaClient)
 
   protected val payloadProps: Properties = Option(metaClient.getTableConfig.getPreCombineField)
     .map { preCombineField =>
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
index ba5c2ed..157f4fe 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
@@ -38,7 +38,6 @@
 import org.apache.hudi.common.util.PartitionPathEncodeUtils
 import org.apache.hudi.common.util.StringUtils.isNullOrEmpty
 import org.apache.hudi.config.HoodieWriteConfig
-import org.apache.hudi.exception.HoodieException
 import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.TimestampType
 import org.apache.hudi.metadata.HoodieTableMetadata
 import org.apache.hudi.testutils.HoodieSparkClientTestBase
@@ -325,28 +324,21 @@
         EqualTo(attribute("dt"), literal("2021/03/01")),
         EqualTo(attribute("hh"), literal("10"))
       )
+      val partitionAndFilesNoPruning = fileIndex.listFiles(Seq(partitionFilter2), Seq.empty)
 
-      // NOTE: That if file-index is in lazy-listing mode and we can't parse partition values, there's no way
-      //       to recover from this since Spark by default have to inject partition values parsed from the partition paths.
-      if (listingModeOverride == DataSourceReadOptions.FILE_INDEX_LISTING_MODE_LAZY) {
-        assertThrows(classOf[HoodieException]) { fileIndex.listFiles(Seq(partitionFilter2), Seq.empty) }
-      } else {
-        val partitionAndFilesNoPruning = fileIndex.listFiles(Seq(partitionFilter2), Seq.empty)
+      assertEquals(1, partitionAndFilesNoPruning.size)
+      // The partition prune would not work for this case, so the partition value it
+      // returns is a InternalRow.empty.
+      assertTrue(partitionAndFilesNoPruning.forall(_.values.numFields == 0))
+      // The returned file size should equal to the whole file size in all the partition paths.
+      assertEquals(getFileCountInPartitionPaths("2021/03/01/10", "2021/03/02/10"),
+        partitionAndFilesNoPruning.flatMap(_.files).length)
 
-        assertEquals(1, partitionAndFilesNoPruning.size)
-        // The partition prune would not work for this case, so the partition value it
-        // returns is a InternalRow.empty.
-        assertTrue(partitionAndFilesNoPruning.forall(_.values.numFields == 0))
-        // The returned file size should equal to the whole file size in all the partition paths.
-        assertEquals(getFileCountInPartitionPaths("2021/03/01/10", "2021/03/02/10"),
-          partitionAndFilesNoPruning.flatMap(_.files).length)
+      val readDF = spark.read.format("hudi").options(readerOpts).load()
 
-        val readDF = spark.read.format("hudi").options(readerOpts).load()
-
-        assertEquals(10, readDF.count())
-        // There are 5 rows in the  dt = 2021/03/01 and hh = 10
-        assertEquals(5, readDF.filter("dt = '2021/03/01' and hh ='10'").count())
-      }
+      assertEquals(10, readDF.count())
+      // There are 5 rows in the  dt = 2021/03/01 and hh = 10
+      assertEquals(5, readDF.filter("dt = '2021/03/01' and hh ='10'").count())
     }
 
     {
@@ -429,7 +421,7 @@
     val partitionAndFilesAfterPrune = fileIndex.listFiles(Seq(partitionFilters), Seq.empty)
     assertEquals(1, partitionAndFilesAfterPrune.size)
 
-    assertEquals(fileIndex.areAllPartitionPathsCached(), !complexExpressionPushDown)
+    assertTrue(fileIndex.areAllPartitionPathsCached())
 
     val PartitionDirectory(partitionActualValues, filesAfterPrune) = partitionAndFilesAfterPrune.head
     val partitionExpectValues = Seq("default", "2021-03-01", "5", "CN")
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
index bfbf425..ad443ff 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala
@@ -62,6 +62,7 @@
 import java.util.function.Consumer
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
+import scala.util.matching.Regex
 
 
 /**
@@ -886,8 +887,8 @@
   }
 
   @ParameterizedTest
-  @EnumSource(value = classOf[HoodieRecordType], names = Array("SPARK"))
-  def testSparkPartitionByWithCustomKeyGenerator(recordType: HoodieRecordType): Unit = {
+  @EnumSource(value = classOf[HoodieRecordType], names = Array("AVRO", "SPARK"))
+  def testSparkPartitionByWithCustomKeyGeneratorWithGlobbing(recordType: HoodieRecordType): Unit = {
     val (writeOpts, readOpts) = getWriterReaderOptsLessPartitionPath(recordType)
 
     // Without fieldType, the default is SIMPLE
@@ -942,6 +943,70 @@
     }
   }
 
+  @ParameterizedTest
+  @EnumSource(value = classOf[HoodieRecordType], names = Array("AVRO", "SPARK"))
+  def testSparkPartitionByWithCustomKeyGenerator(recordType: HoodieRecordType): Unit = {
+    val (writeOpts, readOpts) = getWriterReaderOptsLessPartitionPath(recordType)
+    // Specify fieldType as TIMESTAMP of type EPOCHMILLISECONDS and output date format as yyyy/MM/dd
+    var writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName, writeOpts)
+    writer.partitionBy("current_ts:TIMESTAMP")
+      .option(TIMESTAMP_TYPE_FIELD.key, "EPOCHMILLISECONDS")
+      .option(TIMESTAMP_OUTPUT_DATE_FORMAT.key, "yyyy/MM/dd")
+      .mode(SaveMode.Overwrite)
+      .save(basePath)
+    var recordsReadDF = spark.read.format("hudi")
+      .options(readOpts)
+      .load(basePath)
+    val udf_date_format = udf((data: Long) => new DateTime(data).toString(DateTimeFormat.forPattern("yyyy/MM/dd")))
+
+    assertEquals(0L, recordsReadDF.filter(col("_hoodie_partition_path") =!= udf_date_format(col("current_ts"))).count())
+
+    // Mixed fieldType with TIMESTAMP of type EPOCHMILLISECONDS and output date format as yyyy/MM/dd
+    writer = getDataFrameWriter(classOf[CustomKeyGenerator].getName, writeOpts)
+    writer.partitionBy("driver", "rider:SIMPLE", "current_ts:TIMESTAMP")
+      .option(TIMESTAMP_TYPE_FIELD.key, "EPOCHMILLISECONDS")
+      .option(TIMESTAMP_OUTPUT_DATE_FORMAT.key, "yyyy/MM/dd")
+      .mode(SaveMode.Overwrite)
+      .save(basePath)
+    recordsReadDF = spark.read.format("hudi")
+      .options(readOpts)
+      .load(basePath)
+    assertTrue(recordsReadDF.filter(col("_hoodie_partition_path") =!=
+      concat(col("driver"), lit("/"), col("rider"), lit("/"), udf_date_format(col("current_ts")))).count() == 0)
+  }
+
+  @Test
+  def testPartitionPruningForTimestampBasedKeyGenerator(): Unit = {
+    val (writeOpts, readOpts) = getWriterReaderOptsLessPartitionPath(HoodieRecordType.AVRO, enableFileIndex = true)
+    val writer = getDataFrameWriter(classOf[TimestampBasedKeyGenerator].getName, writeOpts)
+    writer.partitionBy("current_ts")
+      .option(TIMESTAMP_TYPE_FIELD.key, "EPOCHMILLISECONDS")
+      .option(TIMESTAMP_OUTPUT_DATE_FORMAT.key, "yyyy/MM/dd")
+      .mode(SaveMode.Overwrite)
+      .save(basePath)
+
+    val snapshotQueryRes = spark.read.format("hudi")
+      .options(readOpts)
+      .load(basePath)
+      .where("current_ts > '1970/01/16'")
+    assertTrue(checkPartitionFilters(snapshotQueryRes.queryExecution.executedPlan.toString, "current_ts.* > 1970/01/16"))
+  }
+
+  def checkPartitionFilters(sparkPlan: String, partitionFilter: String): Boolean = {
+    val partitionFilterPattern: Regex = """PartitionFilters: \[(.*?)\]""".r
+    val tsPattern: Regex = (partitionFilter).r
+
+    val partitionFilterMatch = partitionFilterPattern.findFirstMatchIn(sparkPlan)
+
+    partitionFilterMatch match {
+      case Some(m) =>
+        val filters = m.group(1)
+        tsPattern.findFirstIn(filters).isDefined
+      case None =>
+        false
+    }
+  }
+
   @Test
   def testSparkPartitionByWithSimpleKeyGenerator() {
     val (writeOpts, readOpts) = getWriterReaderOptsLessPartitionPath(HoodieRecordType.AVRO)