[HUDI-2030] Add metadata cache to WriteProfile to reduce IO (#3090)

Keeps the commit metadata cache in sync with the timeline instants and refreshes
the cache on new commits.
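
A minimal sketch of the caching pattern this patch applies, using a hypothetical
MetadataCacheSketch class with plain Strings standing in for HoodieInstant and
HoodieCommitMetadata (the real code keys the cache by instant timestamp and loads
metadata through WriteProfiles.getCommitMetadataSafely):

    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Map;

    public class MetadataCacheSketch {
      // cache keyed by instant timestamp, kept across profiling runs
      private final Map<String, String> metadataCache = new HashMap<>();

      // loads the metadata only on a cache miss, so the expensive read happens at most once per instant
      public String getOrLoad(String instantTime) {
        return metadataCache.computeIfAbsent(instantTime, this::readMetadataFromStorage);
      }

      // evicts entries whose instants are no longer on the active timeline
      public void cleanCache(List<String> activeInstantTimes) {
        metadataCache.keySet().retainAll(new HashSet<>(activeInstantTimes));
      }

      // placeholder for the real metadata file read (this is the IO the cache avoids)
      private String readMetadataFromStorage(String instantTime) {
        return "metadata-of-" + instantTime;
      }
    }

Note that the patch maps a failed read to null, and Map.computeIfAbsent does not
record null values, so a failed metadata read is retried on the next lookup rather
than being cached.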
diff --git a/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfile.java b/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfile.java
index 71d0d83..3ccfb26 100644
--- a/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfile.java
+++ b/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfile.java
@@ -33,6 +33,7 @@
 import org.apache.hudi.table.action.commit.SmallFile;
 import org.apache.hudi.util.StreamerUtil;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.core.fs.Path;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
@@ -44,7 +45,10 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * Profiling of write statistics for {@link BucketAssigner},
@@ -101,6 +105,11 @@
    */
   private final Configuration hadoopConf;
 
+  /**
+   * Commit metadata cache keyed by instant timestamp, used to reduce the IO of reading metadata files.
+   */
+  private final Map<String, HoodieCommitMetadata> metadataCache;
+
   public WriteProfile(HoodieWriteConfig config, HoodieFlinkEngineContext context) {
     this.config = config;
     this.basePath = new Path(config.getBasePath());
@@ -108,6 +117,7 @@
     this.recordsPerBucket = config.getCopyOnWriteInsertSplitSize();
     this.table = HoodieFlinkTable.create(config, context);
     this.hadoopConf = StreamerUtil.getHadoopConf();
+    this.metadataCache = new HashMap<>();
     // profile the record statistics on construction
     recordProfile();
   }
@@ -132,27 +142,28 @@
     long avgSize = config.getCopyOnWriteRecordSizeEstimate();
     long fileSizeThreshold = (long) (config.getRecordSizeEstimationThreshold() * config.getParquetSmallFileLimit());
     HoodieTimeline commitTimeline = table.getMetaClient().getCommitsTimeline().filterCompletedInstants();
-    try {
-      if (!commitTimeline.empty()) {
-        // Go over the reverse ordered commits to get a more recent estimate of average record size.
-        Iterator<HoodieInstant> instants = commitTimeline.getReverseOrderedInstants().iterator();
-        while (instants.hasNext()) {
-          HoodieInstant instant = instants.next();
-          HoodieCommitMetadata commitMetadata = HoodieCommitMetadata
-              .fromBytes(commitTimeline.getInstantDetails(instant).get(), HoodieCommitMetadata.class);
-          long totalBytesWritten = commitMetadata.fetchTotalBytesWritten();
-          long totalRecordsWritten = commitMetadata.fetchTotalRecordsWritten();
-          if (totalBytesWritten > fileSizeThreshold && totalRecordsWritten > 0) {
-            avgSize = (long) Math.ceil((1.0 * totalBytesWritten) / totalRecordsWritten);
-            break;
-          }
+    if (!commitTimeline.empty()) {
+      // Go over the reverse ordered commits to get a more recent estimate of average record size.
+      Iterator<HoodieInstant> instants = commitTimeline.getReverseOrderedInstants().iterator();
+      while (instants.hasNext()) {
+        HoodieInstant instant = instants.next();
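+        // reuse the cached metadata if present; otherwise read the instant file once and cache the result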
+        final HoodieCommitMetadata commitMetadata =
+            this.metadataCache.computeIfAbsent(
+                instant.getTimestamp(),
+                k -> WriteProfiles.getCommitMetadataSafely(config.getTableName(), basePath, instant, commitTimeline)
+                    .orElse(null));
+        if (commitMetadata == null) {
+          continue;
+        }
+        long totalBytesWritten = commitMetadata.fetchTotalBytesWritten();
+        long totalRecordsWritten = commitMetadata.fetchTotalRecordsWritten();
+        if (totalBytesWritten > fileSizeThreshold && totalRecordsWritten > 0) {
+          avgSize = (long) Math.ceil((1.0 * totalBytesWritten) / totalRecordsWritten);
+          break;
         }
       }
-      LOG.info("AvgRecordSize => " + avgSize);
-    } catch (Throwable t) {
-      // make this fail safe.
-      LOG.error("Error trying to compute average bytes/record ", t);
     }
+    LOG.info("Refresh average bytes per record => " + avgSize);
     return avgSize;
   }
 
@@ -202,21 +213,37 @@
     return smallFileLocations;
   }
 
-  protected void initFSViewIfNecessary(HoodieTimeline commitTimeline) {
+  @VisibleForTesting
+  public void initFSViewIfNecessary(HoodieTimeline commitTimeline) {
     if (fsView == null) {
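+      // evict cached metadata for instants that are no longer on the given timeline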
+      cleanMetadataCache(commitTimeline.getInstants());
       List<HoodieCommitMetadata> metadataList = commitTimeline.getInstants()
-          .map(instant -> WriteProfiles.getCommitMetadata(config.getTableName(), basePath, instant, commitTimeline))
+          .map(instant ->
+              this.metadataCache.computeIfAbsent(
+                  instant.getTimestamp(),
+                  k -> WriteProfiles.getCommitMetadataSafely(config.getTableName(), basePath, instant, commitTimeline)
+                      .orElse(null)))
+          .filter(Objects::nonNull)
           .collect(Collectors.toList());
       FileStatus[] commitFiles = WriteProfiles.getWritePathsOfInstants(basePath, hadoopConf, metadataList);
       fsView = new HoodieTableFileSystemView(table.getMetaClient(), commitTimeline, commitFiles);
     }
   }
 
+  /**
+   * Removes the stale metadata entries from the cache,
+   * i.e. the entries whose instants do not belong to the given {@code instants}.
+   */
+  private void cleanMetadataCache(Stream<HoodieInstant> instants) {
+    Set<String> timestampSet = instants.map(HoodieInstant::getTimestamp).collect(Collectors.toSet());
+    this.metadataCache.keySet().retainAll(timestampSet);
+  }
+
   private void recordProfile() {
     this.avgSize = averageBytesPerRecord();
     if (config.shouldAllowMultiWriteOnSameInstant()) {
       this.recordsPerBucket = config.getParquetMaxFileSize() / avgSize;
-      LOG.info("InsertRecordsPerBucket => " + recordsPerBucket);
+      LOG.info("Refresh insert records per bucket => " + recordsPerBucket);
     }
   }
 
@@ -233,10 +260,15 @@
       // already reloaded
       return;
     }
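+    // reload the timeline first so that the profiling below sees the newly committed instants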
+    this.table.getMetaClient().reloadActiveTimeline();
     recordProfile();
     this.fsView = null;
     this.smallFilesMap.clear();
-    this.table.getMetaClient().reloadActiveTimeline();
     this.reloadedCheckpointId = checkpointId;
   }
+
+  @VisibleForTesting
+  public Map<String, HoodieCommitMetadata> getMetadataCache() {
+    return this.metadataCache;
+  }
 }
diff --git a/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfiles.java b/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfiles.java
index 9a8b7d0..3679c8a 100644
--- a/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfiles.java
+++ b/hudi-flink/src/main/java/org/apache/hudi/sink/partitioner/profile/WriteProfiles.java
@@ -23,6 +23,7 @@
 import org.apache.hudi.common.model.HoodieCommitMetadata;
 import org.apache.hudi.common.table.timeline.HoodieInstant;
 import org.apache.hudi.common.table.timeline.HoodieTimeline;
+import org.apache.hudi.common.util.Option;
 import org.apache.hudi.config.HoodieWriteConfig;
 import org.apache.hudi.exception.HoodieException;
 
@@ -129,6 +130,32 @@
   }
 
   /**
+   * Returns the commit metadata of the given instant safely.
+   *
+   * @param tableName The table name
+   * @param basePath  The table base path
+   * @param instant   The hoodie instant
+   * @param timeline  The timeline
+   *
+   * @return the commit metadata or empty if any error occurs
+   */
+  public static Option<HoodieCommitMetadata> getCommitMetadataSafely(
+      String tableName,
+      Path basePath,
+      HoodieInstant instant,
+      HoodieTimeline timeline) {
+    try {
+      byte[] data = timeline.getInstantDetails(instant).get();
+      return Option.of(HoodieCommitMetadata.fromBytes(data, HoodieCommitMetadata.class));
+    } catch (Exception e) {
+      // make this fail safe: log the error and let the caller skip this instant.
+      LOG.error("Failed to get commit metadata for table {} with instant {} under path {}",
+          tableName, instant.getTimestamp(), basePath, e);
+      return Option.empty();
+    }
+  }
+
+  /**
    * Returns the commit metadata of the given instant.
    *
    * @param tableName The table name
diff --git a/hudi-flink/src/test/java/org/apache/hudi/sink/partitioner/TestBucketAssigner.java b/hudi-flink/src/test/java/org/apache/hudi/sink/partitioner/TestBucketAssigner.java
index 3efa444..1fc6e29 100644
--- a/hudi-flink/src/test/java/org/apache/hudi/sink/partitioner/TestBucketAssigner.java
+++ b/hudi-flink/src/test/java/org/apache/hudi/sink/partitioner/TestBucketAssigner.java
@@ -23,6 +23,7 @@
 import org.apache.hudi.common.config.SerializableConfiguration;
 import org.apache.hudi.common.model.HoodieRecordLocation;
 import org.apache.hudi.common.table.timeline.HoodieInstant;
+import org.apache.hudi.common.table.timeline.HoodieTimeline;
 import org.apache.hudi.common.util.Option;
 import org.apache.hudi.config.HoodieWriteConfig;
 import org.apache.hudi.sink.partitioner.profile.WriteProfile;
@@ -332,6 +333,30 @@
         smallFiles4.get(0).location.getInstantTime(), is(instant2));
   }
 
+  @Test
+  public void testWriteProfileMetadataCache() throws Exception {
+    WriteProfile writeProfile = new WriteProfile(writeConfig, context);
+    assertTrue(writeProfile.getMetadataCache().isEmpty(), "Empty table should not have any instant metadata");
+
+    HoodieTimeline emptyTimeline = writeProfile.getTable().getActiveTimeline();
+
+    // write 3 instants of data
+    for (int i = 0; i < 3; i++) {
+      TestData.writeData(TestData.DATA_SET_INSERT, conf);
+    }
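+    // reload with a new checkpoint id to refresh the timeline and populate the metadata cache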
+    writeProfile.reload(1);
+    assertThat("Metadata cache should have same number entries as timeline instants",
+        writeProfile.getMetadataCache().size(), is(3));
+
+    writeProfile.getSmallFiles("par1");
+    assertThat("The metadata should be reused",
+        writeProfile.getMetadataCache().size(), is(3));
+
+    writeProfile.reload(2);
+    writeProfile.initFSViewIfNecessary(emptyTimeline);
+    assertTrue(writeProfile.getMetadataCache().isEmpty(), "Metadata cache should be emptied for an empty timeline");
+  }
+
   private static Option<String> getLastCompleteInstant(WriteProfile profile) {
     return profile.getTable().getMetaClient().getCommitsTimeline()
         .filterCompletedInstants().lastInstant().map(HoodieInstant::getTimestamp);