[#808] improvement(spark): Verify the number of written records to ensure data correctness (#1558)

### What changes were proposed in this pull request?

Verify the number of written records to ensure data correctness.
Make sure all data records are sent by clients.
Make sure bugs like https://github.com/apache/incubator-uniffle/pull/714 are never reintroduced into the code.

### Why are the changes needed?

A follow-up PR for https://github.com/apache/incubator-uniffle/pull/848.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing UTs.
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 9450c0f..b339984 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -65,6 +65,8 @@
   private AtomicLong usedBytes = new AtomicLong(0);
   // bytes of shuffle data which is in send list
   private AtomicLong inSendListBytes = new AtomicLong(0);
+  /** An atomic counter tracking the total number of records added into the write buffers. */
+  private AtomicLong recordCounter = new AtomicLong(0);
   // it's part of blockId
   private Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
   private long askExecutorMemory;
@@ -236,6 +238,7 @@
     if (wb.getMemoryUsed() > bufferSize) {
       List<ShuffleBlockInfo> sentBlocks = new ArrayList<>(1);
       sentBlocks.add(createShuffleBlock(partitionId, wb));
+      recordCounter.addAndGet(wb.getRecordCount());
       copyTime += wb.getCopyTime();
       buffers.remove(partitionId);
       if (LOG.isDebugEnabled()) {
@@ -298,6 +301,7 @@
       dataSize += wb.getDataLength();
       memoryUsed += wb.getMemoryUsed();
       result.add(createShuffleBlock(entry.getKey(), wb));
+      recordCounter.addAndGet(wb.getRecordCount());
       iterator.remove();
       copyTime += wb.getCopyTime();
     }
@@ -509,6 +513,10 @@
     return inSendListBytes.get();
   }
 
+  protected long getRecordCount() {
+    return recordCounter.get();
+  }
+
   public void freeAllocatedMemory(long freeMemory) {
     freeMemory(freeMemory);
     allocatedBytes.addAndGet(-freeMemory);
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
index 1641da5..ac6ac9e 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriterBuffer.java
@@ -33,6 +33,7 @@
   private List<WrappedBuffer> buffers = Lists.newArrayList();
   private int dataLength = 0;
   private int memoryUsed = 0;
+  private long recordCount = 0;
 
   public WriterBuffer(int bufferSize) {
     this.bufferSize = bufferSize;
@@ -66,6 +67,7 @@
 
     nextOffset += length;
     dataLength += length;
+    recordCount++;
   }
 
   public boolean askForMemory(long length) {
@@ -98,6 +100,10 @@
     return memoryUsed;
   }
 
+  public long getRecordCount() {
+    return recordCount;
+  }
+
   private static final class WrappedBuffer {
 
     byte[] buffer;
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index b5428bd..9e64b2f 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -247,7 +247,9 @@
 
   private void writeImpl(Iterator<Product2<K, V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos;
+    long recordCount = 0;
     while (records.hasNext()) {
+      recordCount++;
       Product2<K, V> record = records.next();
       int partition = getPartition(record._1());
       if (shuffleDependency.mapSideCombine()) {
@@ -264,6 +266,7 @@
     shuffleBlockInfos = bufferManager.clear();
     processShuffleBlockInfos(shuffleBlockInfos);
     long s = System.currentTimeMillis();
+    checkSentRecordCount(recordCount);
     checkBlockSendResult(blockIds);
     final long checkDuration = System.currentTimeMillis() - s;
     long commitDuration = 0;
@@ -291,6 +294,16 @@
             + bufferManager.getManagerCostInfo());
   }
 
+  private void checkSentRecordCount(long recordCount) {
+    if (recordCount != bufferManager.getRecordCount()) {
+      String errorMsg =
+          "Potential record loss may have occurred while preparing to send blocks for task["
+              + taskId
+              + "]";
+      throw new RssSendFailedException(errorMsg);
+    }
+  }
+
   /**
    * ShuffleBlock will be added to queue and send to shuffle server maintenance the following
    * information: 1. add blockId to set, check if it is send later 2. update shuffle server info,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index e7340b8..2fc0340 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -262,7 +262,9 @@
     if (isCombine) {
       createCombiner = shuffleDependency.aggregator().get().createCombiner();
     }
+    long recordCount = 0;
     while (records.hasNext()) {
+      recordCount++;
       // Task should fast fail when sending data failed
       checkIfBlocksFailed();
 
@@ -285,6 +287,7 @@
       processShuffleBlockInfos(shuffleBlockInfos);
     }
     long checkStartTs = System.currentTimeMillis();
+    checkSentRecordCount(recordCount);
     checkBlockSendResult(blockIds);
     long commitStartTs = System.currentTimeMillis();
     long checkDuration = commitStartTs - checkStartTs;
@@ -310,6 +313,16 @@
             + bufferManager.getManagerCostInfo());
   }
 
+  private void checkSentRecordCount(long recordCount) {
+    if (recordCount != bufferManager.getRecordCount()) {
+      String errorMsg =
+          "Potential record loss may have occurred while preparing to send blocks for task["
+              + taskId
+              + "]";
+      throw new RssSendFailedException(errorMsg);
+    }
+  }
+
   // only push-based shuffle use this interface, but rss won't be used when push-based shuffle is
   // enabled.
   public long[] getPartitionLengths() {