HBASE-27778 Incorrect ReplicationSourceWALReader.totalBufferUsed may … (#5158)

HBASE-27778 Incorrect ReplicationSourceWALReader.totalBufferUsed may cause replication hang up
Signed-off-by: Duo Zhang <zhangduo@apache.org>
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/ReplicationSourceWALReader.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/ReplicationSourceWALReader.java
index 4e1d76a..d52ed86 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/ReplicationSourceWALReader.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/ReplicationSourceWALReader.java
@@ -140,11 +140,9 @@
   public void run() {
     int sleepMultiplier = 1;
     while (isReaderRunning()) { // we only loop back here if something fatal happened to our stream
-      WALEntryBatch batch = null;
       try (WALEntryStream entryStream = new WALEntryStream(logQueue, fs, conf, currentPosition,
         source.getWALFileLengthProvider(), source.getSourceMetrics(), walGroupId)) {
         while (isReaderRunning()) { // loop here to keep reusing stream while we can
-          batch = null;
           if (!source.isPeerEnabled()) {
             Threads.sleep(sleepForRetries);
             continue;
@@ -174,14 +172,25 @@
             continue;
           }
           // below are all for hasNext == YES
-          batch = createBatch(entryStream);
-          readWALEntries(entryStream, batch);
-          currentPosition = entryStream.getPosition();
-          // need to propagate the batch even it has no entries since it may carry the last
-          // sequence id information for serial replication.
-          LOG.debug("Read {} WAL entries eligible for replication", batch.getNbEntries());
-          entryBatchQueue.put(batch);
-          sleepMultiplier = 1;
+          WALEntryBatch batch = createBatch(entryStream);
+          boolean successAddToQueue = false;
+          try {
+            readWALEntries(entryStream, batch);
+            currentPosition = entryStream.getPosition();
+            // need to propagate the batch even it has no entries since it may carry the last
+            // sequence id information for serial replication.
+            LOG.debug("Read {} WAL entries eligible for replication", batch.getNbEntries());
+            entryBatchQueue.put(batch);
+            successAddToQueue = true;
+            sleepMultiplier = 1;
+          } finally {
+            if (!successAddToQueue) {
+              // batch is not put to ReplicationSourceWALReader#entryBatchQueue,so we should
+              // decrease ReplicationSourceWALReader.totalBufferUsed by the byte size which
+              // acquired in ReplicationSourceWALReader.acquireBufferQuota.
+              this.releaseBufferQuota(batch);
+            }
+          }
         }
       } catch (WALEntryFilterRetryableException e) {
         // here we have to recreate the WALEntryStream, as when filtering, we have already called
@@ -212,7 +221,7 @@
     long entrySizeExcludeBulkLoad = getEntrySizeExcludeBulkLoad(entry);
     batch.addEntry(entry, entrySize);
     updateBatchStats(batch, entry, entrySize);
-    boolean totalBufferTooLarge = acquireBufferQuota(entrySizeExcludeBulkLoad);
+    boolean totalBufferTooLarge = acquireBufferQuota(batch, entrySizeExcludeBulkLoad);
 
     // Stop if too many entries or too big
     return totalBufferTooLarge || batch.getHeapSize() >= replicationBatchSizeCapacity
@@ -430,13 +439,26 @@
    * @param size delta size for grown buffer
    * @return true if we should clear buffer and push all
    */
-  private boolean acquireBufferQuota(long size) {
+  private boolean acquireBufferQuota(WALEntryBatch walEntryBatch, long size) {
     long newBufferUsed = totalBufferUsed.addAndGet(size);
     // Record the new buffer usage
     this.source.getSourceManager().getGlobalMetrics().setWALReaderEditsBufferBytes(newBufferUsed);
+    walEntryBatch.incrementUsedBufferSize(size);
     return newBufferUsed >= totalBufferQuota;
   }
 
+  /**
+   * To release the buffer quota of {@link WALEntryBatch} which acquired by
+   * {@link ReplicationSourceWALReader#acquireBufferQuota}
+   */
+  private void releaseBufferQuota(WALEntryBatch walEntryBatch) {
+    long usedBufferSize = walEntryBatch.getUsedBufferSize();
+    if (usedBufferSize > 0) {
+      long newBufferUsed = totalBufferUsed.addAndGet(-usedBufferSize);
+      this.source.getSourceManager().getGlobalMetrics().setWALReaderEditsBufferBytes(newBufferUsed);
+    }
+  }
+
   /** Returns whether the reader thread is running */
   public boolean isReaderRunning() {
     return isReaderRunning && !isInterrupted();
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/WALEntryBatch.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/WALEntryBatch.java
index b5ef0f9..32a149d 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/WALEntryBatch.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/replication/regionserver/WALEntryBatch.java
@@ -52,6 +52,9 @@
   private Map<String, Long> lastSeqIds = new HashMap<>();
   // indicate that this is the end of the current file
   private boolean endOfFile;
+  // indicate the buffer size used, which is added to
+  // ReplicationSourceWALReader.totalBufferUsed
+  private long usedBufferSize;
 
   /**
    * @param lastWalPath Path of the WAL the last entry in this batch was read from
@@ -153,11 +156,19 @@
     lastSeqIds.put(region, sequenceId);
   }
 
+  public void incrementUsedBufferSize(long increment) {
+    usedBufferSize += increment;
+  }
+
+  public long getUsedBufferSize() {
+    return this.usedBufferSize;
+  }
+
   @Override
   public String toString() {
     return "WALEntryBatch [walEntries=" + walEntriesWithSize + ", lastWalPath=" + lastWalPath
       + ", lastWalPosition=" + lastWalPosition + ", nbRowKeys=" + nbRowKeys + ", nbHFiles="
       + nbHFiles + ", heapSize=" + heapSize + ", lastSeqIds=" + lastSeqIds + ", endOfFile="
-      + endOfFile + "]";
+      + endOfFile + ",usedBufferSize=" + usedBufferSize + "]";
   }
 }
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/replication/regionserver/TestBasicWALEntryStream.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/replication/regionserver/TestBasicWALEntryStream.java
index efd7685..01f0659 100644
--- a/hbase-server/src/test/java/org/apache/hadoop/hbase/replication/regionserver/TestBasicWALEntryStream.java
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/replication/regionserver/TestBasicWALEntryStream.java
@@ -308,6 +308,14 @@
     when(source.isRecovered()).thenReturn(recovered);
     MetricsReplicationGlobalSourceSource globalMetrics =
       Mockito.mock(MetricsReplicationGlobalSourceSource.class);
+    final AtomicLong bufferUsedCounter = new AtomicLong(0);
+    Mockito.doAnswer((invocationOnMock) -> {
+      bufferUsedCounter.set(invocationOnMock.getArgument(0, Long.class));
+      return null;
+    }).when(globalMetrics).setWALReaderEditsBufferBytes(Mockito.anyLong());
+    when(globalMetrics.getWALReaderEditsBufferBytes())
+      .then(invocationOnMock -> bufferUsedCounter.get());
+
     when(mockSourceManager.getGlobalMetrics()).thenReturn(globalMetrics);
     return source;
   }
@@ -791,4 +799,80 @@
     Waiter.waitFor(localConf, 10000,
       (Waiter.Predicate<Exception>) () -> logQueue.getQueueSize(fakeWalGroupId) == 1);
   }
+
+  /**
+   * This test is for HBASE-27778, when {@link WALEntryFilter#filter} throws exception for some
+   * entries in {@link WALEntryBatch},{@link ReplicationSourceWALReader#totalBufferUsed} should be
+   * decreased because {@link WALEntryBatch} is not put to
+   * {@link ReplicationSourceWALReader#entryBatchQueue}.
+   */
+  @Test
+  public void testReplicationSourceWALReaderWithPartialWALEntryFailingFilter() throws Exception {
+    appendEntriesToLogAndSync(3);
+    // get ending position
+    long position;
+    try (WALEntryStream entryStream =
+      new WALEntryStream(logQueue, fs, CONF, 0, log, new MetricsSource("1"), fakeWalGroupId)) {
+      for (int i = 0; i < 3; i++) {
+        assertNotNull(next(entryStream));
+      }
+      position = entryStream.getPosition();
+    }
+
+    Path walPath = getQueue().peek();
+    int maxThrowExceptionCount = 3;
+
+    ReplicationSource source = mockReplicationSource(false, CONF);
+    when(source.isPeerEnabled()).thenReturn(true);
+    PartialWALEntryFailingWALEntryFilter walEntryFilter =
+      new PartialWALEntryFailingWALEntryFilter(maxThrowExceptionCount, 3);
+    ReplicationSourceWALReader reader =
+      new ReplicationSourceWALReader(fs, CONF, logQueue, 0, walEntryFilter, source, fakeWalGroupId);
+    reader.start();
+    WALEntryBatch entryBatch = reader.take();
+
+    assertNotNull(entryBatch);
+    assertEquals(3, entryBatch.getWalEntries().size());
+    long sum = entryBatch.getWalEntries().stream()
+      .mapToLong(ReplicationSourceWALReader::getEntrySizeExcludeBulkLoad).sum();
+    assertEquals(position, entryBatch.getLastWalPosition());
+    assertEquals(walPath, entryBatch.getLastWalPath());
+    assertEquals(3, entryBatch.getNbRowKeys());
+    assertEquals(sum, source.getSourceManager().getTotalBufferUsed().get());
+    assertEquals(sum, source.getSourceManager().getGlobalMetrics().getWALReaderEditsBufferBytes());
+    assertEquals(maxThrowExceptionCount, walEntryFilter.getThrowExceptionCount());
+    assertNull(reader.poll(10));
+  }
+
+  private static class PartialWALEntryFailingWALEntryFilter implements WALEntryFilter {
+    private int filteredWALEntryCount = -1;
+    private int walEntryCount = 0;
+    private int throwExceptionCount = -1;
+    private int maxThrowExceptionCount;
+
+    public PartialWALEntryFailingWALEntryFilter(int throwExceptionLimit, int walEntryCount) {
+      this.maxThrowExceptionCount = throwExceptionLimit;
+      this.walEntryCount = walEntryCount;
+    }
+
+    @Override
+    public Entry filter(Entry entry) {
+      filteredWALEntryCount++;
+      if (filteredWALEntryCount < walEntryCount - 1) {
+        return entry;
+      }
+
+      filteredWALEntryCount = -1;
+      throwExceptionCount++;
+      if (throwExceptionCount <= maxThrowExceptionCount - 1) {
+        throw new WALEntryFilterRetryableException("failing filter");
+      }
+      return entry;
+    }
+
+    public int getThrowExceptionCount() {
+      return throwExceptionCount;
+    }
+  }
+
 }