perf(spark): Write handle perf fix for prepending meta fields (#13860)

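Before this change, `HoodieAvroIndexedRecord#prependMetaFields` rewrote every record into the meta-field schema through `HoodieAvroUtils.rewriteRecordWithNewSchema`, rebuilding the full record on each write. It now wraps the incoming record in a new `JoinedGenericRecord` instead: the Hudi meta columns are kept in a small array inside the wrapper and every other field access is delegated to the wrapped data record, so no per-record copy is needed. A `CreateHandleBenchmark` is also added to exercise the create-handle write path.

A minimal sketch of how the wrapper resolves names and positions (schemas and field names are borrowed from the new `TestHoodieAvroUtils#testJoinedGenericRecord` test; the helper class below is illustrative only, not part of the patch):

```java
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericRecord;

import org.apache.hudi.avro.JoinedGenericRecord;
import org.apache.hudi.common.model.HoodieRecord;

class JoinedGenericRecordSketch {

  // Wraps a plain data record so it can be written against a target schema that has
  // the Hudi meta columns prepended, without copying any of the data fields.
  static GenericRecord withMetaFields(GenericRecord data, Schema schemaWithMetaFields) {
    int metaFieldCount = schemaWithMetaFields.getFields().size()
        - data.getSchema().getFields().size();
    GenericRecord joined = new JoinedGenericRecord(data, metaFieldCount, schemaWithMetaFields);

    // Meta columns land in the wrapper's own array, addressable by name or by positions 0..metaFieldCount-1.
    joined.put(HoodieRecord.RECORD_KEY_METADATA_FIELD, "recKey1");

    // Positions past the meta columns fall through to the wrapped record, shifted by metaFieldCount.
    Object rowKey = joined.get(metaFieldCount + data.getSchema().getField("_row_key").pos());
    assert rowKey.equals(data.get("_row_key"));
    return joined;
  }
}
```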

---------

Co-authored-by: danny0405 <yuzhao.cyz@gmail.com>
diff --git a/hudi-common/src/main/java/org/apache/hudi/avro/JoinedGenericRecord.java b/hudi-common/src/main/java/org/apache/hudi/avro/JoinedGenericRecord.java
new file mode 100644
index 0000000..9abbbe3
--- /dev/null
+++ b/hudi-common/src/main/java/org/apache/hudi/avro/JoinedGenericRecord.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.avro;
+
+import org.apache.hudi.common.model.HoodieRecord;
+
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericRecord;
+
+/**
+ * Implementation of {@link GenericRecord} that joins the Hudi meta fields with an underlying data record of type {@link GenericRecord}.
+ */
+public class JoinedGenericRecord implements GenericRecord {
+  private final GenericRecord dataRecord;
+  private final Object[] metaFields;
+  private final Schema schema;
+
+  public JoinedGenericRecord(GenericRecord dataRecord, int metaFieldsSize, Schema schema) {
+    this.dataRecord = dataRecord;
+    this.metaFields = new Object[metaFieldsSize];
+    this.schema = schema;
+  }
+
+  @Override
+  public void put(String key, Object v) {
+    Integer metaFieldPos = getMetaFieldPos(key);
+    if (metaFieldPos != null) {
+      metaFields[metaFieldPos] = v;
+    } else {
+      dataRecord.put(key, v);
+    }
+  }
+
+  @Override
+  public Object get(String key) {
+    Integer metaFieldPos = getMetaFieldPos(key);
+    if (metaFieldPos != null) {
+      return metaFields[metaFieldPos];
+    } else {
+      return dataRecord.get(key);
+    }
+  }
+
+  @Override
+  public void put(int i, Object v) {
+    if (i < metaFields.length) {
+      metaFields[i] = v;
+    } else {
+      dataRecord.put(i - metaFields.length, v);
+    }
+  }
+
+  @Override
+  public Object get(int i) {
+    if (i < metaFields.length) {
+      return metaFields[i];
+    } else {
+      return dataRecord.get(i - metaFields.length);
+    }
+  }
+
+  private Integer getMetaFieldPos(String fieldName) {
+    Integer pos = HoodieRecord.HOODIE_META_COLUMNS_NAME_TO_POS.get(fieldName);
+    if (pos == null && fieldName.equals(HoodieRecord.OPERATION_METADATA_FIELD)) {
+      return HoodieRecord.HOODIE_META_COLUMNS_NAME_TO_POS.size();
+    }
+    return pos;
+  }
+
+  @Override
+  public Schema getSchema() {
+    return schema;
+  }
+}
diff --git a/hudi-common/src/main/java/org/apache/hudi/common/model/HoodieAvroIndexedRecord.java b/hudi-common/src/main/java/org/apache/hudi/common/model/HoodieAvroIndexedRecord.java
index 3cdb89e..4787f02 100644
--- a/hudi-common/src/main/java/org/apache/hudi/common/model/HoodieAvroIndexedRecord.java
+++ b/hudi-common/src/main/java/org/apache/hudi/common/model/HoodieAvroIndexedRecord.java
@@ -20,6 +20,7 @@
 
 import org.apache.hudi.avro.AvroRecordContext;
 import org.apache.hudi.avro.HoodieAvroUtils;
+import org.apache.hudi.avro.JoinedGenericRecord;
 import org.apache.hudi.common.table.read.DeleteContext;
 import org.apache.hudi.common.util.ConfigUtils;
 import org.apache.hudi.common.util.Option;
@@ -197,7 +198,9 @@
   @Override
   public HoodieRecord prependMetaFields(Schema recordSchema, Schema targetSchema, MetadataValues metadataValues, Properties props) {
     decodeRecord(recordSchema);
-    GenericRecord newAvroRecord = HoodieAvroUtils.rewriteRecordWithNewSchema(data, targetSchema);
+    GenericRecord genericRecord = (GenericRecord) data;
+    int metaFieldSize = targetSchema.getFields().size() - genericRecord.getSchema().getFields().size();
+    GenericRecord newAvroRecord = metaFieldSize == 0 ? genericRecord : new JoinedGenericRecord(genericRecord, metaFieldSize, targetSchema);
     updateMetadataValuesInternal(newAvroRecord, metadataValues);
     HoodieAvroIndexedRecord newRecord = new HoodieAvroIndexedRecord(key, newAvroRecord, operation, metaData, orderingValue);
     newRecord.setNewLocation(this.newLocation);
diff --git a/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java b/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java
index 6e83b6c..4d5c3b8 100644
--- a/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java
+++ b/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java
@@ -125,6 +125,17 @@
       + "{\"name\": \"pii_col\", \"type\": \"string\", \"column_category\": \"user_profile\"}], "
       + "\"custom_schema_property\": \"custom_schema_property_value\"}";
 
+  private static final String EXAMPLE_SCHEMA_WITH_META_FIELDS = "{\"type\": \"record\",\"name\": \"testrec\",\"fields\": [ "
+      + "{\"name\": \"_hoodie_commit_time\",\"type\": \"string\"},"
+      + "{\"name\": \"_hoodie_commit_seqno\",\"type\": \"string\"},"
+      + "{\"name\": \"_hoodie_record_key\",\"type\": \"string\"},"
+      + "{\"name\": \"_hoodie_partition_path\",\"type\": \"string\"},"
+      + "{\"name\": \"_hoodie_file_name\",\"type\": \"string\"},"
+      + "{\"name\": \"timestamp\",\"type\": \"double\"},"
+      + "{\"name\": \"_row_key\", \"type\": \"string\"},"
+      + "{\"name\": \"non_pii_col\", \"type\": \"string\"},"
+      + "{\"name\": \"pii_col\", \"type\": \"string\", \"column_category\": \"user_profile\"}]}";
+
   private static final int NUM_FIELDS_IN_EXAMPLE_SCHEMA = 4;
 
   private static final String SCHEMA_WITH_METADATA_FIELD = "{\"type\": \"record\",\"name\": \"testrec2\",\"fields\": [ "
@@ -292,6 +303,46 @@
   }
 
   @Test
+  public void testJoinedGenericRecord() {
+    GenericRecord rec = new GenericData.Record(new Schema.Parser().parse(EXAMPLE_SCHEMA));
+    rec.put("_row_key", "key1");
+    rec.put("non_pii_col", "val1");
+    rec.put("pii_col", "val2");
+    rec.put("timestamp", 3.5);
+
+    GenericRecord rec1 = new JoinedGenericRecord(rec, 5, new Schema.Parser().parse(EXAMPLE_SCHEMA_WITH_META_FIELDS));
+    assertNull(rec1.get("_hoodie_commit_time"));
+    assertNull(rec1.get("_hoodie_record_key"));
+
+    assertEquals(rec.get("_row_key"), rec1.get("_row_key"));
+    assertEquals(rec.get("_row_key"), rec1.get(6));
+    assertEquals(rec.get("non_pii_col"), rec1.get("non_pii_col"));
+    assertEquals(rec.get("non_pii_col"), rec1.get(7));
+    assertEquals(rec.get("pii_col"), rec1.get("pii_col"));
+    assertEquals(rec.get("pii_col"), rec1.get(8));
+    assertEquals(rec.get("timestamp"), rec1.get("timestamp"));
+    assertEquals(rec.get("timestamp"), rec1.get(5));
+
+    // let's set the meta field values and validate both name-based and positional access
+    rec1.put(0, "commitTime1");
+    rec1.put(1, "commitSecNo1");
+    rec1.put(2, "recKey1");
+    rec1.put(3, "pPath1");
+    rec1.put(4, "fileName");
+
+    assertEquals("commitTime1", rec1.get(0));
+    assertEquals("commitTime1", rec1.get(HoodieRecord.COMMIT_TIME_METADATA_FIELD));
+    assertEquals("commitSecNo1", rec1.get(1));
+    assertEquals("commitSecNo1", rec1.get(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD));
+    assertEquals("recKey1", rec1.get(2));
+    assertEquals("recKey1", rec1.get(HoodieRecord.RECORD_KEY_METADATA_FIELD));
+    assertEquals("pPath1", rec1.get(3));
+    assertEquals("pPath1", rec1.get(HoodieRecord.PARTITION_PATH_METADATA_FIELD));
+    assertEquals("fileName", rec1.get(4));
+    assertEquals("fileName", rec1.get(HoodieRecord.FILENAME_METADATA_FIELD));
+  }
+
+  @Test
   public void testNonNullableFieldWithoutDefault() {
     GenericRecord rec = new GenericData.Record(new Schema.Parser().parse(EXAMPLE_SCHEMA));
     rec.put("_row_key", "key1");
diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/TestMergeHandle.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/TestMergeHandle.java
index e4c9ada..5cfe2e2 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/TestMergeHandle.java
+++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/TestMergeHandle.java
@@ -19,7 +19,7 @@
 package org.apache.hudi.io;
 
 import org.apache.hudi.avro.AvroSchemaUtils;
-import org.apache.hudi.avro.HoodieAvroUtils;
+import org.apache.hudi.avro.JoinedGenericRecord;
 import org.apache.hudi.client.SecondaryIndexStats;
 import org.apache.hudi.client.SparkRDDWriteClient;
 import org.apache.hudi.client.WriteStatus;
@@ -70,9 +70,7 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
-import org.mockito.Answers;
-import org.mockito.MockedStatic;
-import org.mockito.Mockito;
+import org.mockito.MockedConstruction;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -91,13 +89,13 @@
 import static org.apache.hudi.common.testutils.HoodieTestDataGenerator.AVRO_SCHEMA;
 import static org.apache.hudi.common.testutils.HoodieTestDataGenerator.TRIP_EXAMPLE_SCHEMA;
 import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.AssertionsKt.assertNull;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mockConstruction;
 
 /**
  * Unit tests {@link HoodieMergeHandle}.
@@ -236,30 +234,23 @@
     FileGroupReaderBasedMergeHandle fileGroupReaderBasedMergeHandle = new FileGroupReaderBasedMergeHandle(
         config, instantTime, table, updates.iterator(), partitionPath, fileId, new LocalTaskContextSupplier(),
         Option.empty());
-
+    List<WriteStatus> writeStatuses;
     String recordKeyForFailure = updates.get(5).getRecordKey();
-    try (MockedStatic<HoodieAvroUtils> mockedStatic = mockStatic(HoodieAvroUtils.class, Mockito.withSettings().defaultAnswer(Answers.CALLS_REAL_METHODS))) {
-      int position = AVRO_SCHEMA.getField("_row_key").pos();
-      mockedStatic.when(() -> HoodieAvroUtils.rewriteRecordWithNewSchema(any(), any())).thenAnswer(invocationOnMock -> {
-        IndexedRecord record = invocationOnMock.getArgument(0);
-        if (record.get(position).toString().equals(recordKeyForFailure)) {
-          throw new HoodieIOException("Simulated write failure for record key: " + recordKeyForFailure);
-        }
-        return HoodieAvroUtils.rewriteRecordWithNewSchema((IndexedRecord) invocationOnMock.getArgument(0), invocationOnMock.getArgument(1), Collections.emptyMap());
-      });
+
+    try (MockedConstruction<JoinedGenericRecord> mocked = mockConstruction(JoinedGenericRecord.class,
+        (mock, context) -> {
+          doThrow(new HoodieIOException("Simulated write failure for record key: " + recordKeyForFailure))
+              .when(mock).put(any(), any());
+        })) {
       fileGroupReaderBasedMergeHandle.doMerge();
     }
 
-    List<WriteStatus> writeStatuses = fileGroupReaderBasedMergeHandle.close();
+    writeStatuses = fileGroupReaderBasedMergeHandle.close();
     WriteStatus writeStatus = writeStatuses.get(0);
-    assertEquals(1, writeStatus.getErrors().size());
+    assertEquals(2, writeStatus.getErrors().size());
-    // check that record and secondary index stats are non-empty
+    // check that record delegates and secondary index stats are empty
-    assertFalse(writeStatus.getWrittenRecordDelegates().isEmpty());
-    assertFalse(writeStatus.getIndexStats().getSecondaryIndexStats().values().stream().flatMap(Collection::stream).count() == 0L);
-
-    writeStatus.getWrittenRecordDelegates().forEach(recordDelegate -> assertNotEquals(recordKeyForFailure, recordDelegate.getRecordKey()));
-    writeStatus.getIndexStats().getSecondaryIndexStats().values().stream().flatMap(Collection::stream)
-        .forEach(secondaryIndexStats -> assertNotEquals(recordKeyForFailure, secondaryIndexStats.getRecordKey()));
+    assertTrue(writeStatus.getWrittenRecordDelegates().isEmpty());
+    assertTrue(writeStatus.getIndexStats().getSecondaryIndexStats().values().stream().flatMap(Collection::stream).count() == 0L);
 
     AtomicBoolean cdcRecordsFound = new AtomicBoolean(false);
     String cdcFilePath = metaClient.getBasePath().toString() + "/" + writeStatus.getStat().getCdcStats().keySet().stream().findFirst().get();
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/execution/benchmark/CreateHandleBenchmark.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/execution/benchmark/CreateHandleBenchmark.scala
new file mode 100644
index 0000000..dad0dee
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/execution/benchmark/CreateHandleBenchmark.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.hudi.AvroConversionUtils
+import org.apache.hudi.HoodieSparkUtils
+import org.apache.hudi.client.WriteStatus
+import org.apache.hudi.common.config.HoodieMetadataConfig
+import org.apache.hudi.common.data.HoodieData
+import org.apache.hudi.common.engine.{HoodieLocalEngineContext, LocalTaskContextSupplier}
+import org.apache.hudi.common.model.{DefaultHoodieRecordPayload, HoodieAvroIndexedRecord, HoodieKey, HoodieRecord}
+import org.apache.hudi.common.table.HoodieTableConfig
+import org.apache.hudi.common.table.marker.MarkerType
+import org.apache.hudi.common.util.HoodieRecordUtils
+import org.apache.hudi.config.HoodieWriteConfig
+import org.apache.hudi.io.HoodieCreateHandle
+import org.apache.hudi.keygen.constant.KeyGeneratorOptions
+import org.apache.hudi.storage.hadoop.HadoopStorageConfiguration
+import org.apache.hudi.table.HoodieSparkTable
+import org.apache.hudi.table.HoodieTable
+
+import org.apache.avro.generic.IndexedRecord
+import org.apache.spark.hudi.benchmark.{HoodieBenchmark, HoodieBenchmarkBase}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.types._
+
+import java.util.{Properties, UUID}
+import java.util.stream.Collectors
+
+import scala.util.Random
+
+object CreateHandleBenchmark extends HoodieBenchmarkBase {
+  protected val spark: SparkSession = getSparkSession
+
+  def getSparkSession: SparkSession = SparkSession
+    .builder()
+    .master("local[1]")
+    .config("spark.driver.memory", "8G")
+    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    .appName(this.getClass.getCanonicalName)
+    .getOrCreate()
+
+  def getDataFrame(numbers: Int): DataFrame = {
+    val rand = new Random(42)
+    val schema = createRandomSchema(numCols = 100, maxDepth = 5)
+    val rows = (1 to numbers).map(_ => generateRow(schema, rand))
+    spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+  }
+
+  def createRandomSchema(numCols: Int, maxDepth: Int): StructType = {
+    val types = Seq("string", "long", "int", "array", "map", "struct")
+    val fields = (1 to numCols).map { i =>
+      val dataType = types((i - 1) % types.length)
+      val colName = s"col$i"
+
+      val fieldType = dataType match {
+        case "string" => StringType
+        case "long" => LongType
+        case "int" => IntegerType
+        case "array" => ArrayType(StringType, containsNull = false)
+        case "map" => MapType(StringType, IntegerType, valueContainsNull = false)
+        case "struct" => generateNestedStruct(maxDepth)
+      }
+
+      StructField(colName, fieldType, nullable = false)
+    }
+
+    StructType(StructField("key", StringType, nullable = false) +: fields)
+  }
+
+  def generateNestedStruct(depth: Int): StructType = {
+    if (depth <= 0) {
+      StructType(Seq(
+        StructField("leafStr", StringType, nullable = false),
+        StructField("leafInt", IntegerType, nullable = false)
+      ))
+    } else {
+      StructType(Seq(
+        StructField("nestedStr", StringType, nullable = false),
+        StructField("nestedInt", IntegerType, nullable = false),
+        StructField("nestedStruct", generateNestedStruct(depth - 1), nullable = false)
+      ))
+    }
+  }
+
+  def generateRow(schema: StructType, rand: Random): Row = {
+    val values = schema.fields.map {
+      case StructField("key", _, _, _) => java.util.UUID.randomUUID().toString
+      case StructField(_, StringType, _, _) => s"str_${rand.nextInt(100)}"
+      case StructField(_, LongType, _, _) => rand.nextLong()
+      case StructField(_, IntegerType, _, _) => rand.nextInt(100)
+      case StructField(_, ArrayType(_, _), _, _) => Seq.fill(3)(s"arr_${rand.nextInt(100)}")
+      case StructField(_, MapType(_, _, _), _, _) => Map("a" -> rand.nextInt(10), "b" -> rand.nextInt(10))
+      case StructField(_, s: StructType, _, _) => generateRow(s, rand)
+      case _ => throw new RuntimeException("Unsupported type")
+    }
+    Row.fromSeq(values)
+  }
+
+  private def createHandleBenchmark: Unit = {
+    val benchmark = new HoodieBenchmark(s"perf create handle for hoodie", 10000)
+    val df = getDataFrame(100000)
+    val avroSchema = AvroConversionUtils.convertStructTypeToAvroSchema(df.schema, "record", "my")
+    spark.sparkContext.getConf.registerAvroSchemas(avroSchema)
+
+    df.write.format("hudi").option(KeyGeneratorOptions.RECORDKEY_FIELD_NAME.key(), "key")
+      .option(HoodieMetadataConfig.ENABLE.key(), "false")
+      .option(HoodieTableConfig.NAME.key(), "tbl_name").mode(SaveMode.Overwrite).save("/tmp/sample_test_table")
+    val dummyProps = new Properties()
+    val avroRecords: java.util.List[HoodieRecord[_]] = HoodieSparkUtils.createRdd(df, "struct_name", "name_space",
+      Some(avroSchema)).mapPartitions(
+      it => {
+        it.map { genRec =>
+          val hoodieKey = new HoodieKey(genRec.get("key").toString, "")
+          HoodieRecordUtils.createHoodieRecord(genRec, 0L, hoodieKey, classOf[DefaultHoodieRecordPayload].getName, false)
+        }
+      }).toJavaRDD().collect().stream().map[HoodieRecord[_]](hoodieRec => {
+      hoodieRec.asInstanceOf[HoodieAvroIndexedRecord].toIndexedRecord(avroSchema, dummyProps)
+      hoodieRec
+    }).collect(Collectors.toList[HoodieRecord[_]])
+
+    benchmark.addCase("create handle perf benchmark") { _ =>
+      val props = new Properties()
+      props.setProperty(KeyGeneratorOptions.RECORDKEY_FIELD_NAME.key(), "key")
+      val writeConfig = HoodieWriteConfig.newBuilder().withPath("/tmp/sample_test_table").withPreCombineField("col1")
+        .withSchema(avroSchema.toString)
+        .withMarkersType(MarkerType.DIRECT.name())
+        .withMetadataConfig(HoodieMetadataConfig.newBuilder().enable(false).build())
+        .withProps(props).build()
+
+      val engineContext = new HoodieLocalEngineContext(new HadoopStorageConfiguration(spark.sparkContext.hadoopConfiguration))
+      val hoodieTable: HoodieTable[_, HoodieData[HoodieRecord[_]], HoodieData[HoodieKey], HoodieData[WriteStatus]] =
+        HoodieSparkTable.create(writeConfig, engineContext).asInstanceOf[HoodieTable[_, HoodieData[HoodieRecord[_]], HoodieData[HoodieKey], HoodieData[WriteStatus]]]
+      val createHandle = new HoodieCreateHandle(writeConfig, "000000001", hoodieTable, "", UUID.randomUUID().toString, new LocalTaskContextSupplier())
+      avroRecords.forEach(record => {
+        val newAvroRec = new HoodieAvroIndexedRecord(record.getKey, record.getData.asInstanceOf[IndexedRecord], 0L, record.getOperation)
+        createHandle.write(newAvroRec, avroSchema, writeConfig.getProps)
+      })
+      createHandle.close()
+    }
+    benchmark.run()
+  }
+
+  override def afterAll(): Unit = {
+    spark.stop()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    createHandleBenchmark
+  }
+}